FST  openfst-1.8.4
OpenFst Library
float-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Float weight set and associated semiring operation definitions.
19 
20 #ifndef FST_FLOAT_WEIGHT_H_
21 #define FST_FLOAT_WEIGHT_H_
22 
23 #include <algorithm>
24 #include <climits>
25 #include <cmath>
26 #include <cstddef>
27 #include <cstdint>
28 #include <cstdlib>
29 #include <cstring>
30 #include <ios>
31 #include <istream>
32 #include <limits>
33 #include <ostream>
34 #include <random>
35 #include <sstream>
36 #include <string>
37 #include <type_traits>
38 
39 #include <fst/log.h>
40 #include <fst/util.h>
41 #include <fst/weight.h>
42 #include <fst/compat.h>
43 #include <string_view>
44 
45 namespace fst {
46 
namespace internal {
// Returns true iff `value` is a NaN, relying on NaN being the only value that
// does not compare equal to itself. Unlike `std::isnan`, this can be used in
// constant expressions before C++23.
// TODO(wolfsonkin): Replace with `std::isnan` when C++23 can be used.
template <class T>
inline constexpr bool IsNan(T value) {
  return !(value == value);
}
}  // namespace internal
55 
// Numeric limits class: the special values of the floating-point type T that
// the weight classes use as semiring constants and as the "bad number" marker.
template <class T>
class FloatLimits {
 public:
  // Positive infinity of T.
  static constexpr T PosInfinity() {
    return std::numeric_limits<T>::infinity();
  }

  // Negative infinity of T.
  static constexpr T NegInfinity() {
    return -std::numeric_limits<T>::infinity();
  }

  // A quiet NaN of T, used to represent an invalid number.
  static constexpr T NumberBad() {
    return std::numeric_limits<T>::quiet_NaN();
  }
};
68 
// Weight class to be templated on floating-point types.
70 template <class T = float>
72  public:
73  using ValueType = T;
74 
75  FloatWeightTpl() noexcept = default;
76 
77  constexpr FloatWeightTpl(T f) : value_(f) {} // NOLINT
78 
79  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }
80 
81  std::ostream &Write(std::ostream &strm) const {
82  return WriteType(strm, value_);
83  }
84 
85  size_t Hash() const {
86  size_t hash = 0;
87  // Avoid using union, which would be undefined behavior.
88  // Use memcpy, similar to bit_cast, but sizes may be different.
89  // This should be optimized into a single move instruction by
90  // any reasonable compiler.
91  std::memcpy(&hash, &value_, std::min(sizeof(hash), sizeof(value_)));
92  return hash;
93  }
94 
95  constexpr const T &Value() const { return value_; }
96 
97  protected:
98  void SetValue(const T &f) { value_ = f; }
99 
100  static constexpr std::string_view GetPrecisionString() {
101  return sizeof(T) == 4 ? ""
102  : sizeof(T) == 1 ? "8"
103  : sizeof(T) == 2 ? "16"
104  : sizeof(T) == 8 ? "64"
105  : "unknown";
106  }
107 
108  private:
109  T value_;
110 };
111 
112 // Single-precision float weight.
114 
115 template <class T>
116 constexpr bool operator==(const FloatWeightTpl<T> &w1,
117  const FloatWeightTpl<T> &w2) {
118 #if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
119 // With i387 instructions, excess precision on a weight in an 80-bit
120 // register may cause it to compare unequal to that same weight when
121 // stored to memory. This breaks =='s reflexivity, in turn breaking
122 // NaturalLess.
123 #error "Please compile with -msse -mfpmath=sse, or equivalent."
124 #endif
125  return w1.Value() == w2.Value();
126 }
127 
128 // These seemingly unnecessary overloads are actually needed to make
129 // comparisons like FloatWeightTpl<float> == float compile. If only the
130 // templated version exists, the FloatWeightTpl<float>(float) conversion
131 // won't be found.
132 constexpr bool operator==(const FloatWeightTpl<float> &w1,
133  const FloatWeightTpl<float> &w2) {
134  return operator==<float>(w1, w2);
135 }
136 
137 constexpr bool operator==(const FloatWeightTpl<double> &w1,
138  const FloatWeightTpl<double> &w2) {
139  return operator==<double>(w1, w2);
140 }
141 
142 template <class T>
143 constexpr bool operator!=(const FloatWeightTpl<T> &w1,
144  const FloatWeightTpl<T> &w2) {
145  return !(w1 == w2);
146 }
147 
148 constexpr bool operator!=(const FloatWeightTpl<float> &w1,
149  const FloatWeightTpl<float> &w2) {
150  return operator!=<float>(w1, w2);
151 }
152 
153 constexpr bool operator!=(const FloatWeightTpl<double> &w1,
154  const FloatWeightTpl<double> &w2) {
155  return operator!=<double>(w1, w2);
156 }
157 
158 template <class T>
159 constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) {
160  return w1 <= w2 + delta && w2 <= w1 + delta;
161 }
162 
163 template <class T>
164 constexpr bool ApproxEqual(const FloatWeightTpl<T> &w1,
165  const FloatWeightTpl<T> &w2, float delta = kDelta) {
166  return FloatApproxEqual(w1.Value(), w2.Value(), delta);
167 }
168 
169 template <class T>
170 inline std::ostream &operator<<(std::ostream &strm,
171  const FloatWeightTpl<T> &w) {
172  if (w.Value() == FloatLimits<T>::PosInfinity()) {
173  return strm << "Infinity";
174  } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
175  return strm << "-Infinity";
176  } else if (internal::IsNan(w.Value())) {
177  return strm << "BadNumber";
178  } else {
179  return strm << w.Value();
180  }
181 }
182 
183 template <class T>
184 inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
185  std::string s;
186  strm >> s;
187  if (s == "Infinity") {
189  } else if (s == "-Infinity") {
191  } else {
192  char *p;
193  T f = strtod(s.c_str(), &p);
194  if (p < s.c_str() + s.size()) {
195  strm.clear(std::ios::badbit);
196  } else {
197  w = FloatWeightTpl<T>(f);
198  }
199  }
200  return strm;
201 }
202 
203 // Tropical semiring: (min, +, inf, 0).
204 template <class T>
206  public:
207  using typename FloatWeightTpl<T>::ValueType;
211 
212  TropicalWeightTpl() noexcept : FloatWeightTpl<T>() {}
213 
214  constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
215 
216  static constexpr TropicalWeightTpl<T> Zero() { return Limits::PosInfinity(); }
217 
218  static constexpr TropicalWeightTpl<T> One() { return 0; }
219 
220  static constexpr TropicalWeightTpl<T> NoWeight() {
221  return Limits::NumberBad();
222  }
223 
224  static const std::string &Type() {
225  static const std::string *const type = new std::string(
227  return *type;
228  }
229 
230  constexpr bool Member() const {
231  // All floating point values except for NaNs and negative infinity are valid
232  // tropical weights.
233  //
234  // Testing membership of a given value can be done by simply checking that
235  // it is strictly greater than negative infinity, which fails for negative
236  // infinity itself but also for NaNs. This can usually be accomplished in a
237  // single instruction (such as *UCOMI* on x86) without branching logic.
238  //
239  // An additional wrinkle involves constexpr correctness of floating point
240  // comparisons against NaN. GCC is uneven when it comes to which expressions
241  // it considers compile-time constants. In particular, current versions of
242  // GCC do not always consider (nan < inf) to be a constant expression, but
243  // do consider (inf < nan) to be a constant expression. (See
244  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and
245  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683 for details.) In order
246  // to allow Member() to be a constexpr function accepted by GCC, we write
247  // the comparison here as (-inf < v).
248  return Limits::NegInfinity() < Value();
249  }
250 
251  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
252  if (!Member() || Value() == Limits::PosInfinity()) {
253  return *this;
254  } else {
255  return TropicalWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
256  }
257  }
258 
259  constexpr TropicalWeightTpl<T> Reverse() const { return *this; }
260 
261  static constexpr uint64_t Properties() {
263  }
264 };
265 
266 // Single precision tropical weight.
268 
269 template <class T>
271  const TropicalWeightTpl<T> &w2) {
272  return (!w1.Member() || !w2.Member()) ? TropicalWeightTpl<T>::NoWeight()
273  : w1.Value() < w2.Value() ? w1
274  : w2;
275 }
276 
277 // See comment at operator==(FloatWeightTpl<float>, FloatWeightTpl<float>)
278 // for why these overloads are present.
280  const TropicalWeightTpl<float> &w2) {
281  return Plus<float>(w1, w2);
282 }
283 
285  const TropicalWeightTpl<double> &w2) {
286  return Plus<double>(w1, w2);
287 }
288 
289 template <class T>
291  const TropicalWeightTpl<T> &w2) {
292  // The following is safe in the context of the Tropical (and Log) semiring
293  // for all IEEE floating point values, including infinities and NaNs,
294  // because:
295  //
296  // * If one or both of the floating point Values is NaN and hence not a
297  // Member, the result of addition below is NaN, so the result is not a
298  // Member. This supersedes all other cases, so we only consider non-NaN
299  // values next.
300  //
301  // * If both Values are finite, there is no issue.
302  //
303  // * If one of the Values is infinite, or if both are infinities with the
304  // same sign, the result of floating point addition is the same infinity,
305  // so there is no issue.
306  //
307  // * If both of the Values are infinities with opposite signs, the result of
308  // adding IEEE floating point -inf + inf is NaN and hence not a Member. But
309  // since -inf was not a Member to begin with, returning a non-Member result
310  // is fine as well.
311  return TropicalWeightTpl<T>(w1.Value() + w2.Value());
312 }
313 
315  const TropicalWeightTpl<float> &w2) {
316  return Times<float>(w1, w2);
317 }
318 
320  const TropicalWeightTpl<double> &w2) {
321  return Times<double>(w1, w2);
322 }
323 
324 template <class T>
326  const TropicalWeightTpl<T> &w2,
327  DivideType typ = DIVIDE_ANY) {
328  // The following is safe in the context of the Tropical (and Log) semiring
329  // for all IEEE floating point values, including infinities and NaNs,
330  // because:
331  //
332  // * If one or both of the floating point Values is NaN and hence not a
333  // Member, the result of subtraction below is NaN, so the result is not a
334  // Member. This supersedes all other cases, so we only consider non-NaN
335  // values next.
336  //
337  // * If both Values are finite, there is no issue.
338  //
339  // * If w2.Value() is -inf (and hence w2 is not a Member), the result of ?:
340  // below is NoWeight, which is not a Member.
341  //
342  // Whereas in IEEE floating point semantics 0/inf == 0, this does not carry
343  // over to this semiring (since TropicalWeight(-inf) would be the analogue
344  // of floating point inf) and instead Divide(Zero(), TropicalWeight(-inf))
345  // is NoWeight().
346  //
347  // * If w2.Value() is inf (and hence w2 is Zero), the resulting floating
348  // point value is either NaN (if w1 is Zero or if w1.Value() is NaN) and
349  // hence not a Member, or it is -inf and hence not a Member; either way,
350  // division by Zero results in a non-Member result.
351  using Weight = TropicalWeightTpl<T>;
352  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
353 }
354 
356  const TropicalWeightTpl<float> &w2,
357  DivideType typ = DIVIDE_ANY) {
358  return Divide<float>(w1, w2, typ);
359 }
360 
362  const TropicalWeightTpl<double> &w2,
363  DivideType typ = DIVIDE_ANY) {
364  return Divide<double>(w1, w2, typ);
365 }
366 
367 // Power(w, n) calculates the n-th power of w with respect to semiring Times.
368 //
369 // In the case of the Tropical (and Log) semiring, the exponent n is not
370 // restricted to be an integer. It can be a floating point value, for example.
371 //
372 // In weight.h, a narrower and hence more broadly applicable version of
373 // Power(w, n) is defined for arbitrary weight types and non-negative integer
374 // exponents n (of type size_t) and implemented in terms of repeated
375 // multiplication using Times.
376 //
377 // Without further provisions this means that, when an expression such as
378 //
379 // Power(TropicalWeightTpl<float>::One(), static_cast<size_t>(2))
380 //
381 // is specified, the overload of Power() is ambiguous. The template function
382 // below could be instantiated as
383 //
384 // Power<float, size_t>(const TropicalWeightTpl<float> &, size_t)
385 //
386 // and the template function defined in weight.h (further specialized below)
387 // could be instantiated as
388 //
389 // Power<TropicalWeightTpl<float>>(const TropicalWeightTpl<float> &, size_t)
390 //
391 // That would lead to two definitions with identical signatures, which results
392 // in a compilation error. To avoid that, we hide the definition of Power<T, V>
393 // when V is size_t, so only Power<W> is visible. Power<W> is further
394 // specialized to Power<TropicalWeightTpl<...>>, and the overloaded definition
395 // of Power<T, V> is made conditionally available only to that template
396 // specialization.
397 
398 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
399  typename std::enable_if_t<Enable> * = nullptr>
401  using Weight = TropicalWeightTpl<T>;
402  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
403  : (n == 0 || w == Weight::One()) ? Weight::One()
404  : Weight(w.Value() * n);
405 }
406 
// Specializes the library-wide template to use the above implementation; rules
// of function template specialization require this be a full specialization.
409 
410 template <>
411 constexpr TropicalWeightTpl<float> Power<TropicalWeightTpl<float>>(
412  const TropicalWeightTpl<float> &weight, size_t n) {
413  return Power<float, size_t, true>(weight, n);
414 }
415 
416 template <>
417 constexpr TropicalWeightTpl<double> Power<TropicalWeightTpl<double>>(
418  const TropicalWeightTpl<double> &weight, size_t n) {
419  return Power<double, size_t, true>(weight, n);
420 }
421 
// Log semiring: (-log(e^-x + e^-y), +, inf, 0).
423 template <class T>
424 class LogWeightTpl : public FloatWeightTpl<T> {
425  public:
426  using typename FloatWeightTpl<T>::ValueType;
430 
431  LogWeightTpl() noexcept : FloatWeightTpl<T>() {}
432 
433  constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
434 
435  static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); }
436 
437  static constexpr LogWeightTpl One() { return 0; }
438 
439  static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); }
440 
441  static const std::string &Type() {
442  static const std::string *const type = new std::string(
444  return *type;
445  }
446 
447  constexpr bool Member() const {
448  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
449  return Limits::NegInfinity() < Value();
450  }
451 
452  LogWeightTpl<T> Quantize(float delta = kDelta) const {
453  if (!Member() || Value() == Limits::PosInfinity()) {
454  return *this;
455  } else {
456  return LogWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
457  }
458  }
459 
460  constexpr LogWeightTpl<T> Reverse() const { return *this; }
461 
462  static constexpr uint64_t Properties() {
464  }
465 };
466 
467 // Single-precision log weight.
469 
470 // Double-precision log weight.
472 
473 namespace internal {
474 
475 // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x.
476 inline double LogPosExp(double x) {
477  DCHECK(!(x < 0)); // NB: NaN values are allowed.
478  return log1p(exp(-x));
479 }
480 
481 // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x.
482 inline double LogNegExp(double x) {
483  DCHECK(!(x < 0)); // NB: NaN values are allowed.
484  return log1p(-exp(-x));
485 }
486 
487 // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
488 // Kahan compensated summation provides an error bound that is
489 // independent of the number of addends. Assumes b >= a;
490 // c is the compensation.
491 inline double KahanLogSum(double a, double b, double *c) {
492  DCHECK_GE(b, a);
493  double y = -LogPosExp(b - a) - *c;
494  double t = a + y;
495  *c = (t - a) - y;
496  return t;
497 }
498 
499 // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
500 // Kahan compensated summation provides an error bound that is
501 // independent of the number of addends. Assumes b > a;
502 // c is the compensation.
503 inline double KahanLogDiff(double a, double b, double *c) {
504  DCHECK_GT(b, a);
505  double y = -LogNegExp(b - a) - *c;
506  double t = a + y;
507  *c = (t - a) - y;
508  return t;
509 }
510 
511 } // namespace internal
512 
513 template <class T>
515  const LogWeightTpl<T> &w2) {
516  using Limits = FloatLimits<T>;
517  const T f1 = w1.Value();
518  const T f2 = w2.Value();
519  if (f1 == Limits::PosInfinity()) {
520  return w2;
521  } else if (f2 == Limits::PosInfinity()) {
522  return w1;
523  } else if (f1 > f2) {
524  return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
525  } else {
526  return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
527  }
528 }
529 
531  const LogWeightTpl<float> &w2) {
532  return Plus<float>(w1, w2);
533 }
534 
536  const LogWeightTpl<double> &w2) {
537  return Plus<double>(w1, w2);
538 }
539 
540 // Returns NoWeight if w1 < w2 (w1.Value() > w2.Value()).
541 template <class T>
543  const LogWeightTpl<T> &w2) {
544  using Limits = FloatLimits<T>;
545  const T f1 = w1.Value();
546  const T f2 = w2.Value();
547  if (f1 > f2) return LogWeightTpl<T>::NoWeight();
548  if (f2 == Limits::PosInfinity()) return f1;
549  const T d = f2 - f1;
550  if (d == Limits::PosInfinity()) return f1;
551  return f1 - internal::LogNegExp(d);
552 }
553 
555  const LogWeightTpl<float> &w2) {
556  return Minus<float>(w1, w2);
557 }
558 
560  const LogWeightTpl<double> &w2) {
561  return Minus<double>(w1, w2);
562 }
563 
564 template <class T>
566  const LogWeightTpl<T> &w2) {
567  // The comments for Times(Tropical...) above apply here unchanged.
568  return LogWeightTpl<T>(w1.Value() + w2.Value());
569 }
570 
572  const LogWeightTpl<float> &w2) {
573  return Times<float>(w1, w2);
574 }
575 
577  const LogWeightTpl<double> &w2) {
578  return Times<double>(w1, w2);
579 }
580 
581 template <class T>
583  const LogWeightTpl<T> &w2,
584  DivideType typ = DIVIDE_ANY) {
585  // The comments for Divide(Tropical...) above apply here unchanged.
586  using Weight = LogWeightTpl<T>;
587  return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
588 }
589 
591  const LogWeightTpl<float> &w2,
592  DivideType typ = DIVIDE_ANY) {
593  return Divide<float>(w1, w2, typ);
594 }
595 
597  const LogWeightTpl<double> &w2,
598  DivideType typ = DIVIDE_ANY) {
599  return Divide<double>(w1, w2, typ);
600 }
601 
602 // The comments for Power<>(Tropical...) above apply here unchanged.
603 
604 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
605  typename std::enable_if_t<Enable> * = nullptr>
606 constexpr LogWeightTpl<T> Power(const LogWeightTpl<T> &w, V n) {
607  using Weight = LogWeightTpl<T>;
608  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
609  : (n == 0 || w == Weight::One()) ? Weight::One()
610  : Weight(w.Value() * n);
611 }
612 
// Specializes the library-wide template to use the above implementation; rules
// of function template specialization require this be a full specialization.
615 
616 template <>
617 constexpr LogWeightTpl<float> Power<LogWeightTpl<float>>(
618  const LogWeightTpl<float> &weight, size_t n) {
619  return Power<float, size_t, true>(weight, n);
620 }
621 
622 template <>
623 constexpr LogWeightTpl<double> Power<LogWeightTpl<double>>(
624  const LogWeightTpl<double> &weight, size_t n) {
625  return Power<double, size_t, true>(weight, n);
626 }
627 
628 // Specialization using the Kahan compensated summation.
629 template <class T>
630 class Adder<LogWeightTpl<T>> {
631  public:
633 
634  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
635 
636  Weight Add(const Weight &w) {
637  using Limits = FloatLimits<T>;
638  const T f = w.Value();
639  if (f == Limits::PosInfinity()) {
640  return Sum();
641  } else if (sum_ == Limits::PosInfinity()) {
642  sum_ = f;
643  c_ = 0.0;
644  } else if (f > sum_) {
645  sum_ = internal::KahanLogSum(sum_, f, &c_);
646  } else {
647  sum_ = internal::KahanLogSum(f, sum_, &c_);
648  }
649  return Sum();
650  }
651 
652  Weight Sum() const { return Weight(sum_); }
653 
654  void Reset(Weight w = Weight::Zero()) {
655  sum_ = w.Value();
656  c_ = 0.0;
657  }
658 
659  private:
660  double sum_;
661  double c_; // Kahan compensation.
662 };
663 
664 // Real semiring: (+, *, 0, 1).
665 template <class T>
666 class RealWeightTpl : public FloatWeightTpl<T> {
667  public:
668  using typename FloatWeightTpl<T>::ValueType;
672 
673  RealWeightTpl() noexcept : FloatWeightTpl<T>() {}
674 
675  constexpr RealWeightTpl(T f) : FloatWeightTpl<T>(f) {}
676 
677  static constexpr RealWeightTpl Zero() { return 0; }
678 
679  static constexpr RealWeightTpl One() { return 1; }
680 
681  static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); }
682 
683  static const std::string &Type() {
684  static const std::string *const type = new std::string(
686  return *type;
687  }
688 
689  constexpr bool Member() const {
690  // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
691  return Limits::NegInfinity() < Value();
692  }
693 
694  RealWeightTpl<T> Quantize(float delta = kDelta) const {
695  if (!Member() || Value() == Limits::PosInfinity()) {
696  return *this;
697  } else {
698  return RealWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
699  }
700  }
701 
702  constexpr RealWeightTpl<T> Reverse() const { return *this; }
703 
704  static constexpr uint64_t Properties() {
706  }
707 };
708 
// Single-precision real weight.
711 
// Double-precision real weight.
714 
namespace internal {

// a + b = KahanRealSum(a, b, ...).
// Kahan compensated summation provides an error bound that is independent of
// the number of addends. *c is the compensation (the rounding error left over
// from previous additions), carried across successive calls and updated in
// place; callers start it at 0.
inline double KahanRealSum(double a, double b, double *c) {
  const double y = b - *c;  // Addend, corrected by the previous error.
  const double t = a + y;   // Rounded new sum.
  *c = (t - a) - y;         // Rounding error lost in computing t.
  return t;
}

}  // namespace internal
728 
729 // The comments for Times(Tropical...) above apply here unchanged.
730 template <class T>
732  const RealWeightTpl<T> &w2) {
733  const T f1 = w1.Value();
734  const T f2 = w2.Value();
735  return RealWeightTpl<T>(f1 + f2);
736 }
737 
739  const RealWeightTpl<float> &w2) {
740  return Plus<float>(w1, w2);
741 }
742 
744  const RealWeightTpl<double> &w2) {
745  return Plus<double>(w1, w2);
746 }
747 
748 template <class T>
750  const RealWeightTpl<T> &w2) {
751  // The comments for Divide(Tropical...) above apply here unchanged.
752  const T f1 = w1.Value();
753  const T f2 = w2.Value();
754  return RealWeightTpl<T>(f1 - f2);
755 }
756 
758  const RealWeightTpl<float> &w2) {
759  return Minus<float>(w1, w2);
760 }
761 
763  const RealWeightTpl<double> &w2) {
764  return Minus<double>(w1, w2);
765 }
766 
767 // The comments for Times(Tropical...) above apply here similarly.
768 template <class T>
770  const RealWeightTpl<T> &w2) {
771  return RealWeightTpl<T>(w1.Value() * w2.Value());
772 }
773 
775  const RealWeightTpl<float> &w2) {
776  return Times<float>(w1, w2);
777 }
778 
780  const RealWeightTpl<double> &w2) {
781  return Times<double>(w1, w2);
782 }
783 
784 template <class T>
786  const RealWeightTpl<T> &w2,
787  DivideType typ = DIVIDE_ANY) {
788  using Weight = RealWeightTpl<T>;
789  return w2.Member() ? Weight(w1.Value() / w2.Value()) : Weight::NoWeight();
790 }
791 
793  const RealWeightTpl<float> &w2,
794  DivideType typ = DIVIDE_ANY) {
795  return Divide<float>(w1, w2, typ);
796 }
797 
799  const RealWeightTpl<double> &w2,
800  DivideType typ = DIVIDE_ANY) {
801  return Divide<double>(w1, w2, typ);
802 }
803 
804 // The comments for Power<>(Tropical...) above apply here unchanged.
805 
806 template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
807  typename std::enable_if_t<Enable> * = nullptr>
808 constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
809  using Weight = RealWeightTpl<T>;
810  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
811  : (n == 0 || w == Weight::One()) ? Weight::One()
812  : Weight(pow(w.Value(), n));
813 }
814 
// Specializes the library-wide template to use the above implementation; rules
// of function template specialization require this be a full specialization.
817 
818 template <>
819 constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
820  const RealWeightTpl<float> &weight, size_t n) {
821  return Power<float, size_t, true>(weight, n);
822 }
823 
824 template <>
825 constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
826  const RealWeightTpl<double> &weight, size_t n) {
827  return Power<double, size_t, true>(weight, n);
828 }
829 
830 // Specialization using the Kahan compensated summation.
831 template <class T>
832 class Adder<RealWeightTpl<T>> {
833  public:
835 
836  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
837 
838  Weight Add(const Weight &w) {
839  using Limits = FloatLimits<T>;
840  const T f = w.Value();
841  if (f == Limits::PosInfinity()) {
842  sum_ = f;
843  } else if (sum_ == Limits::PosInfinity()) {
844  return sum_;
845  } else {
846  sum_ = internal::KahanRealSum(sum_, f, &c_);
847  }
848  return Sum();
849  }
850 
851  Weight Sum() const { return Weight(sum_); }
852 
853  void Reset(Weight w = Weight::Zero()) {
854  sum_ = w.Value();
855  c_ = 0.0;
856  }
857 
858  private:
859  double sum_;
860  double c_; // Kahan compensation.
861 };
862 
863 // MinMax semiring: (min, max, inf, -inf).
864 template <class T>
865 class MinMaxWeightTpl : public FloatWeightTpl<T> {
866  public:
867  using typename FloatWeightTpl<T>::ValueType;
871 
872  MinMaxWeightTpl() noexcept : FloatWeightTpl<T>() {}
873 
874  constexpr MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} // NOLINT
875 
876  static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); }
877 
878  static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); }
879 
880  static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); }
881 
882  static const std::string &Type() {
883  static const std::string *const type = new std::string(
885  return *type;
886  }
887 
888  // Fails for IEEE NaN.
889  constexpr bool Member() const { return !internal::IsNan(Value()); }
890 
891  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
892  // If one of infinities, or a NaN.
893  if (!Member() || Value() == Limits::NegInfinity() ||
894  Value() == Limits::PosInfinity()) {
895  return *this;
896  } else {
897  return MinMaxWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
898  }
899  }
900 
901  constexpr MinMaxWeightTpl<T> Reverse() const { return *this; }
902 
903  static constexpr uint64_t Properties() {
905  }
906 };
907 
908 // Single-precision min-max weight.
910 
911 // Min.
912 template <class T>
914  const MinMaxWeightTpl<T> &w2) {
915  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
916  : w1.Value() < w2.Value() ? w1
917  : w2;
918 }
919 
921  const MinMaxWeightTpl<float> &w2) {
922  return Plus<float>(w1, w2);
923 }
924 
926  const MinMaxWeightTpl<double> &w2) {
927  return Plus<double>(w1, w2);
928 }
929 
930 // Max.
931 template <class T>
933  const MinMaxWeightTpl<T> &w2) {
934  return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
935  : w1.Value() >= w2.Value() ? w1
936  : w2;
937 }
938 
940  const MinMaxWeightTpl<float> &w2) {
941  return Times<float>(w1, w2);
942 }
943 
945  const MinMaxWeightTpl<double> &w2) {
946  return Times<double>(w1, w2);
947 }
948 
949 // Defined only for special cases.
950 template <class T>
952  const MinMaxWeightTpl<T> &w2,
953  DivideType typ = DIVIDE_ANY) {
954  return w1.Value() >= w2.Value() ? w1 : MinMaxWeightTpl<T>::NoWeight();
955 }
956 
958  const MinMaxWeightTpl<float> &w2,
959  DivideType typ = DIVIDE_ANY) {
960  return Divide<float>(w1, w2, typ);
961 }
962 
964  const MinMaxWeightTpl<double> &w2,
965  DivideType typ = DIVIDE_ANY) {
966  return Divide<double>(w1, w2, typ);
967 }
968 
969 // Converts to tropical.
970 template <>
972  constexpr TropicalWeight operator()(const LogWeight &w) const {
973  return w.Value();
974  }
975 };
976 
977 template <>
979  constexpr TropicalWeight operator()(const Log64Weight &w) const {
980  return w.Value();
981  }
982 };
983 
984 // Converts to log.
985 template <>
987  constexpr LogWeight operator()(const TropicalWeight &w) const {
988  return w.Value();
989  }
990 };
991 
992 template <>
994  LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); }
995 };
996 
997 template <>
999  LogWeight operator()(const Real64Weight &w) const { return -log(w.Value()); }
1000 };
1001 
1002 template <>
1004  constexpr LogWeight operator()(const Log64Weight &w) const {
1005  return w.Value();
1006  }
1007 };
1008 
1009 // Converts to log64.
1010 template <>
1012  constexpr Log64Weight operator()(const TropicalWeight &w) const {
1013  return w.Value();
1014  }
1015 };
1016 
1017 template <>
1019  Log64Weight operator()(const RealWeight &w) const { return -log(w.Value()); }
1020 };
1021 
1022 template <>
1025  return -log(w.Value());
1026  }
1027 };
1028 
1029 template <>
1031  constexpr Log64Weight operator()(const LogWeight &w) const {
1032  return w.Value();
1033  }
1034 };
1035 
1036 // Converts to real.
1037 template <>
1039  RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1040 };
1041 
1042 template <>
1044  RealWeight operator()(const Log64Weight &w) const { return exp(-w.Value()); }
1045 };
1046 
1047 template <>
1049  constexpr RealWeight operator()(const Real64Weight &w) const {
1050  return w.Value();
1051  }
1052 };
1053 
1054 // Converts to real64
1055 template <>
1057  Real64Weight operator()(const LogWeight &w) const { return exp(-w.Value()); }
1058 };
1059 
1060 template <>
1063  return exp(-w.Value());
1064  }
1065 };
1066 
1067 template <>
1069  constexpr Real64Weight operator()(const RealWeight &w) const {
1070  return w.Value();
1071  }
1072 };
1073 
// This function object returns random integers chosen from [0,
// num_random_weights). The allow_zero argument determines whether Zero() and
// zero divisors should be returned in the random weight generation. This is
// intended primarily for testing.
1078 template <class Weight>
1080  public:
1082  uint64_t seed = std::random_device()(), bool allow_zero = true,
1083  const size_t num_random_weights = kNumRandomWeights)
1084  : rand_(seed),
1085  allow_zero_(allow_zero),
1086  num_random_weights_(num_random_weights) {}
1087 
1088  Weight operator()() const {
1089  const int sample = std::uniform_int_distribution<>(
1090  0, num_random_weights_ + allow_zero_ - 1)(rand_);
1091  if (allow_zero_ && sample == num_random_weights_) return Weight::Zero();
1092  return Weight(sample);
1093  }
1094 
1095  private:
1096  mutable std::mt19937_64 rand_;
1097  const bool allow_zero_;
1098  const size_t num_random_weights_;
1099 };
1100 
1101 template <class T>
1103  : public FloatWeightGenerate<TropicalWeightTpl<T>> {
1104  public:
1107 
1108  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1109  bool allow_zero = true,
1110  size_t num_random_weights = kNumRandomWeights)
1111  : Generate(seed, allow_zero, num_random_weights) {}
1112 
1113  Weight operator()() const { return Weight(Generate::operator()()); }
1114 };
1115 
1116 template <class T>
1118  : public FloatWeightGenerate<LogWeightTpl<T>> {
1119  public:
1122 
1123  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1124  bool allow_zero = true,
1125  size_t num_random_weights = kNumRandomWeights)
1126  : Generate(seed, allow_zero, num_random_weights) {}
1127 
1128  Weight operator()() const { return Weight(Generate::operator()()); }
1129 };
1130 
1131 template <class T>
1133  : public FloatWeightGenerate<RealWeightTpl<T>> {
1134  public:
1137 
1138  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1139  bool allow_zero = true,
1140  size_t num_random_weights = kNumRandomWeights)
1141  : Generate(seed, allow_zero, num_random_weights) {}
1142 
1143  Weight operator()() const { return Weight(Generate::operator()()); }
1144 };
1145 
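// This function object returns random integer-valued weights sampled uniformly
// from [-num_random_weights, num_random_weights], or from
// [-num_random_weights, num_random_weights + 1] when allow_zero is set. The
// lowest sample is returned as One(). The boolean 'allow_zero' determines
// whether Zero() and zero divisors should be returned in the random weight
// generation; if it is set, a sample of 0 is returned as Zero(). This is
// intended primarily for testing.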
1150 template <class T>
1152  public:
1154 
1155  explicit WeightGenerate(uint64_t seed = std::random_device()(),
1156  bool allow_zero = true,
1157  size_t num_random_weights = kNumRandomWeights)
1158  : rand_(seed),
1159  allow_zero_(allow_zero),
1160  num_random_weights_(num_random_weights) {}
1161 
1162  Weight operator()() const {
1163  const int sample = std::uniform_int_distribution<>(
1164  -num_random_weights_, num_random_weights_ + allow_zero_)(rand_);
1165  if (allow_zero_ && sample == 0) {
1166  return Weight::Zero();
1167  } else if (sample == -num_random_weights_) {
1168  return Weight::One();
1169  } else {
1170  return Weight(sample);
1171  }
1172  }
1173 
1174  private:
1175  mutable std::mt19937_64 rand_;
1176  const bool allow_zero_;
1177  const size_t num_random_weights_;
1178 };
1179 
1180 } // namespace fst
1181 
1182 #endif // FST_FLOAT_WEIGHT_H_
RealWeight operator()(const Log64Weight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Weight Add(const Weight &w)
Definition: float-weight.h:838
static constexpr RealWeightTpl Zero()
Definition: float-weight.h:677
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:654
static constexpr TropicalWeightTpl< T > Zero()
Definition: float-weight.h:216
constexpr TropicalWeight operator()(const LogWeight &w) const
Definition: float-weight.h:972
FloatWeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, const size_t num_random_weights=kNumRandomWeights)
constexpr RealWeight operator()(const Real64Weight &w) const
constexpr Log64Weight operator()(const LogWeight &w) const
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
constexpr bool Member() const
Definition: float-weight.h:689
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:836
RealWeight operator()(const LogWeight &w) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
static constexpr std::string_view GetPrecisionString()
Definition: float-weight.h:100
RealWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:694
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:435
static constexpr MinMaxWeightTpl NoWeight()
Definition: float-weight.h:880
static constexpr MinMaxWeightTpl One()
Definition: float-weight.h:878
constexpr uint64_t kIdempotent
Definition: weight.h:147
RealWeightTpl() noexcept
Definition: float-weight.h:673
#define DCHECK_GT(x, y)
Definition: log.h:77
std::ostream & Write(std::ostream &strm) const
Definition: float-weight.h:81
Weight operator()() const
constexpr Log64Weight operator()(const TropicalWeight &w) const
void SetValue(const T &f)
Definition: float-weight.h:98
static constexpr uint64_t Properties()
Definition: float-weight.h:261
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
static constexpr T PosInfinity()
Definition: float-weight.h:60
constexpr RealWeightTpl< T > Reverse() const
Definition: float-weight.h:702
static constexpr uint64_t Properties()
Definition: float-weight.h:903
TropicalWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:251
constexpr uint64_t kRightSemiring
Definition: weight.h:139
std::ostream & WriteType(std::ostream &strm, const T t)
Definition: util.h:229
static constexpr LogWeightTpl One()
Definition: float-weight.h:437
static constexpr TropicalWeightTpl< T > One()
Definition: float-weight.h:218
constexpr bool FloatApproxEqual(T w1, T w2, float delta=kDelta)
Definition: float-weight.h:159
static constexpr T NumberBad()
Definition: float-weight.h:66
std::istream & operator>>(std::istream &strm, FloatWeightTpl< T > &w)
Definition: float-weight.h:184
Log64Weight operator()(const Real64Weight &w) const
Real64Weight operator()(const LogWeight &w) const
constexpr bool IsNan(T value)
Definition: float-weight.h:51
constexpr TropicalWeightTpl< T > Reverse() const
Definition: float-weight.h:259
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
constexpr LogWeight operator()(const TropicalWeight &w) const
Definition: float-weight.h:987
MinMaxWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:891
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:439
constexpr uint64_t kCommutative
Definition: weight.h:144
std::ostream & operator<<(std::ostream &strm, const ErrorWeight &)
Definition: error-weight.h:71
LogWeight operator()(const RealWeight &w) const
Definition: float-weight.h:994
Adder(Weight w=Weight::Zero())
Definition: float-weight.h:634
void Reset(Weight w=Weight::Zero())
Definition: float-weight.h:853
constexpr MinMaxWeightTpl(T f)
Definition: float-weight.h:874
static constexpr RealWeightTpl NoWeight()
Definition: float-weight.h:681
static const std::string & Type()
Definition: float-weight.h:683
LogWeightTpl< T > Quantize(float delta=kDelta) const
Definition: float-weight.h:452
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:220
static constexpr uint64_t Properties()
Definition: float-weight.h:704
static constexpr RealWeightTpl One()
Definition: float-weight.h:679
constexpr LogWeightTpl< T > Reverse() const
Definition: float-weight.h:460
double LogPosExp(double x)
Definition: float-weight.h:476
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:491
TropicalWeightTpl() noexcept
Definition: float-weight.h:212
LogWeight operator()(const Real64Weight &w) const
Definition: float-weight.h:999
std::istream & Read(std::istream &strm)
Definition: float-weight.h:79
constexpr uint64_t kPath
Definition: weight.h:150
static const std::string & Type()
Definition: float-weight.h:224
Log64Weight operator()(const RealWeight &w) const
Weight Add(const Weight &w)
Definition: float-weight.h:636
LogWeightTpl() noexcept
Definition: float-weight.h:431
constexpr bool Member() const
Definition: float-weight.h:230
double LogNegExp(double x)
Definition: float-weight.h:482
constexpr RealWeightTpl(T f)
Definition: float-weight.h:675
MinMaxWeightTpl() noexcept
Definition: float-weight.h:872
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:542
constexpr bool Member() const
Definition: float-weight.h:889
std::string StrCat(const StringOrInt &s1, const StringOrInt &s2)
Definition: compat.h:295
static const std::string & Type()
Definition: float-weight.h:441
constexpr bool Member() const
Definition: float-weight.h:447
size_t Hash() const
Definition: float-weight.h:85
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
constexpr size_t kNumRandomWeights
Definition: weight.h:154
static constexpr MinMaxWeightTpl Zero()
Definition: float-weight.h:876
constexpr TropicalWeight operator()(const Log64Weight &w) const
Definition: float-weight.h:979
static const std::string & Type()
Definition: float-weight.h:882
std::istream & ReadType(std::istream &strm, T *t)
Definition: util.h:81
constexpr MinMaxWeightTpl< T > Reverse() const
Definition: float-weight.h:901
#define DCHECK(x)
Definition: log.h:74
static constexpr uint64_t Properties()
Definition: float-weight.h:462
constexpr TropicalWeightTpl< T > Power(const TropicalWeightTpl< T > &w, V n)
Definition: float-weight.h:400
#define DCHECK_GE(x, y)
Definition: log.h:79
constexpr LogWeight operator()(const Log64Weight &w) const
DivideType
Definition: weight.h:165
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
constexpr Real64Weight operator()(const RealWeight &w) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
Real64Weight operator()(const Log64Weight &w) const
constexpr const T & Value() const
Definition: float-weight.h:95
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
constexpr LogWeightTpl(T f)
Definition: float-weight.h:433
constexpr TropicalWeightTpl(T f)
Definition: float-weight.h:214
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:503
static constexpr T NegInfinity()
Definition: float-weight.h:64
double KahanRealSum(double a, double b, double *c)
Definition: float-weight.h:720