FST  openfst-1.8.3
OpenFst Library
signed-log-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // LogWeight along with sign information that represents the value X in the
19 // linear domain as <sign(X), -ln(|X|)>
20 //
21 // The sign is a TropicalWeight:
22 // positive, TropicalWeight.Value() > 0.0, recommended value 1.0
23 // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
24 
25 #ifndef FST_SIGNED_LOG_WEIGHT_H_
26 #define FST_SIGNED_LOG_WEIGHT_H_
27 
28 #include <climits>
29 #include <cmath>
30 #include <cstddef>
31 #include <cstdint>
32 #include <cstdlib>
33 #include <random>
34 #include <string>
35 
36 #include <fst/log.h>
37 #include <fst/float-weight.h>
38 #include <fst/pair-weight.h>
39 #include <fst/product-weight.h>
40 #include <fst/util.h>
41 #include <fst/weight.h>
42 
43 namespace fst {
44 template <class T>
45 class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
46  public:
47  using W1 = TropicalWeight;
50 
53 
54  SignedLogWeightTpl() noexcept : PairWeight<W1, W2>() {}
55 
56  // Conversion from plain LogWeightTpl.
57  // NOLINTNEXTLINE(google-explicit-constructor)
58  SignedLogWeightTpl(const W2 &w2) : PairWeight<W1, W2>(W1(1.0), w2) {}
59 
60  explicit SignedLogWeightTpl(const PairWeight<W1, W2> &weight)
61  : PairWeight<W1, W2>(weight) {}
62 
63  SignedLogWeightTpl(const W1 &w1, const W2 &w2) : PairWeight<W1, W2>(w1, w2) {}
64 
65  static const SignedLogWeightTpl &Zero() {
66  static const SignedLogWeightTpl zero(W1(1.0), W2::Zero());
67  return zero;
68  }
69 
70  static const SignedLogWeightTpl &One() {
71  static const SignedLogWeightTpl one(W1(1.0), W2::One());
72  return one;
73  }
74 
75  static const SignedLogWeightTpl &NoWeight() {
76  static const SignedLogWeightTpl no_weight(W1(1.0), W2::NoWeight());
77  return no_weight;
78  }
79 
80  static const std::string &Type() {
81  static const std::string *const type =
82  new std::string("signed_log_" + W1::Type() + "_" + W2::Type());
83  return *type;
84  }
85 
86  bool IsPositive() const { return Value1().Value() > 0; }
87 
88  SignedLogWeightTpl Quantize(float delta = kDelta) const {
90  }
91 
94  }
95 
96  bool Member() const { return PairWeight<W1, W2>::Member(); }
97 
98  // Neither idempotent nor path.
99  static constexpr uint64_t Properties() {
101  }
102 
103  size_t Hash() const {
104  size_t h1;
105  if (Value2() == W2::Zero() || IsPositive()) {
106  h1 = TropicalWeight(1.0).Hash();
107  } else {
108  h1 = TropicalWeight(-1.0).Hash();
109  }
110  size_t h2 = Value2().Hash();
111  static constexpr int lshift = 5;
112  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
113  return h1 << lshift ^ h1 >> rshift ^ h2;
114  }
115 };
116 
117 template <class T>
119  const SignedLogWeightTpl<T> &w2) {
120  using W1 = TropicalWeight;
121  using W2 = LogWeightTpl<T>;
122  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
123  const auto s1 = w1.IsPositive();
124  const auto s2 = w2.IsPositive();
125  const bool equal = (s1 == s2);
126  const auto f1 = w1.Value2().Value();
127  const auto f2 = w2.Value2().Value();
128  if (f1 == FloatLimits<T>::PosInfinity()) {
129  return w2;
130  } else if (f2 == FloatLimits<T>::PosInfinity()) {
131  return w1;
132  } else if (f1 == f2) {
133  if (equal) {
134  return SignedLogWeightTpl<T>(W1(w1.Value1()), W2(f2 - M_LN2));
135  } else {
137  }
138  } else if (f1 > f2) {
139  if (equal) {
140  return SignedLogWeightTpl<T>(W1(w1.Value1()),
141  W2(f2 - internal::LogPosExp(f1 - f2)));
142  } else {
143  return SignedLogWeightTpl<T>(W1(w2.Value1()),
144  W2((f2 - internal::LogNegExp(f1 - f2))));
145  }
146  } else {
147  if (equal) {
148  return SignedLogWeightTpl<T>(W1(w2.Value1()),
149  W2((f1 - internal::LogPosExp(f2 - f1))));
150  } else {
151  return SignedLogWeightTpl<T>(W1(w1.Value1()),
152  W2((f1 - internal::LogNegExp(f2 - f1))));
153  }
154  }
155 }
156 
157 template <class T>
159  const SignedLogWeightTpl<T> &w2) {
160  SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
161  return Plus(w1, minus_w2);
162 }
163 
164 template <class T>
166  const SignedLogWeightTpl<T> &w2) {
167  using W2 = LogWeightTpl<T>;
168  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
169  const auto s1 = w1.IsPositive();
170  const auto s2 = w2.IsPositive();
171  const auto f1 = w1.Value2().Value();
172  const auto f2 = w2.Value2().Value();
173  if (s1 == s2) {
174  return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 + f2));
175  } else {
176  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 + f2));
177  }
178 }
179 
180 template <class T>
182  const SignedLogWeightTpl<T> &w2,
183  DivideType typ = DIVIDE_ANY) {
184  using W2 = LogWeightTpl<T>;
185  if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
186  const auto s1 = w1.IsPositive();
187  const auto s2 = w2.IsPositive();
188  const auto f1 = w1.Value2().Value();
189  const auto f2 = w2.Value2().Value();
190  if (f2 == FloatLimits<T>::PosInfinity()) {
193  } else if (f1 == FloatLimits<T>::PosInfinity()) {
196  } else if (s1 == s2) {
197  return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 - f2));
198  } else {
199  return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 - f2));
200  }
201 }
202 
203 template <class T>
204 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
205  const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
206  using W2 = LogWeightTpl<T>;
207  if (w1.IsPositive() == w2.IsPositive()) {
208  return ApproxEqual(w1.Value2(), w2.Value2(), delta);
209  } else {
210  return ApproxEqual(w1.Value2(), W2::Zero(), delta) &&
211  ApproxEqual(w2.Value2(), W2::Zero(), delta);
212  }
213 }
214 
215 template <class T>
216 inline bool operator==(const SignedLogWeightTpl<T> &w1,
217  const SignedLogWeightTpl<T> &w2) {
218  using W2 = LogWeightTpl<T>;
219  if (w1.IsPositive() == w2.IsPositive()) {
220  return w1.Value2() == w2.Value2();
221  } else {
222  return w1.Value2() == W2::Zero() && w2.Value2() == W2::Zero();
223  }
224 }
225 
226 template <class T>
227 inline bool operator!=(const SignedLogWeightTpl<T> &w1,
228  const SignedLogWeightTpl<T> &w2) {
229  return !(w1 == w2);
230 }
231 
// All functions and operators that take a LogWeightTpl argument must be
// overloaded explicitly: the implicit SignedLogWeightTpl(const W2 &)
// constructor is not considered during template argument deduction.
235 
236 template <class T>
238  const SignedLogWeightTpl<T> &w2) {
239  return Plus(SignedLogWeightTpl<T>(w1), w2);
240 }
241 
242 template <class T>
244  const LogWeightTpl<T> &w2) {
245  return Plus(w1, SignedLogWeightTpl<T>(w2));
246 }
247 
248 template <class T>
250  const SignedLogWeightTpl<T> &w2) {
251  return Minus(SignedLogWeightTpl<T>(w1), w2);
252 }
253 
254 template <class T>
256  const LogWeightTpl<T> &w2) {
257  return Minus(w1, SignedLogWeightTpl<T>(w2));
258 }
259 
260 template <class T>
262  const SignedLogWeightTpl<T> &w2) {
263  return Times(SignedLogWeightTpl<T>(w1), w2);
264 }
265 
266 template <class T>
268  const LogWeightTpl<T> &w2) {
269  return Times(w1, SignedLogWeightTpl<T>(w2));
270 }
271 
272 template <class T>
274  const SignedLogWeightTpl<T> &w2,
275  DivideType typ = DIVIDE_ANY) {
276  return Divide(SignedLogWeightTpl<T>(w1), w2, typ);
277 }
278 
279 template <class T>
281  const LogWeightTpl<T> &w2,
282  DivideType typ = DIVIDE_ANY) {
283  return Divide(w1, SignedLogWeightTpl<T>(w2), typ);
284 }
285 
286 template <class T>
287 inline bool ApproxEqual(const LogWeightTpl<T> &w1,
288  const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
289  return ApproxEqual(LogWeightTpl<T>(w1), w2, delta);
290 }
291 
292 template <class T>
293 inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
294  const LogWeightTpl<T> &w2, float delta = kDelta) {
295  return ApproxEqual(w1, LogWeightTpl<T>(w2), delta);
296 }
297 
298 template <class T>
299 inline bool operator==(const LogWeightTpl<T> &w1,
300  const SignedLogWeightTpl<T> &w2) {
301  return SignedLogWeightTpl<T>(w1) == w2;
302 }
303 
304 template <class T>
305 inline bool operator==(const SignedLogWeightTpl<T> &w1,
306  const LogWeightTpl<T> &w2) {
307  return w1 == SignedLogWeightTpl<T>(w2);
308 }
309 
310 template <class T>
311 inline bool operator!=(const LogWeightTpl<T> &w1,
312  const SignedLogWeightTpl<T> &w2) {
313  return SignedLogWeightTpl<T>(w1) != w2;
314 }
315 
316 template <class T>
317 inline bool operator!=(const SignedLogWeightTpl<T> &w1,
318  const LogWeightTpl<T> &w2) {
319  return w1 != SignedLogWeightTpl<T>(w2);
320 }
321 
322 // Single-precision signed-log weight.
324 
325 // Double-precision signed-log weight.
327 
328 template <class W1, class W2>
329 bool SignedLogConvertCheck(W1 weight) {
330  if (weight.Value1().Value() < 0.0) {
331  FSTERROR() << "WeightConvert: Can't convert weight " << weight << " from "
332  << W1::Type() << " to " << W2::Type();
333  return false;
334  }
335  return true;
336 }
337 
338 // Specialization using the Kahan compensated summation
339 template <class T>
341  public:
345 
346  explicit Adder(Weight w = Weight::Zero())
347  : ssum_(w.IsPositive()), sum_(w.Value2().Value()), c_(0.0) {}
348 
349  Weight Add(const Weight &w) {
350  const auto sw = w.IsPositive();
351  const auto f = w.Value2().Value();
352  const bool equal = (ssum_ == sw);
353 
354  if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
355  return Sum();
356  } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
357  sum_ = f;
358  ssum_ = sw;
359  c_ = 0.0;
360  } else if (f == sum_) {
361  if (equal) {
362  sum_ = internal::KahanLogSum(sum_, f, &c_);
363  } else {
365  ssum_ = true;
366  c_ = 0.0;
367  }
368  } else if (f > sum_) {
369  if (equal) {
370  sum_ = internal::KahanLogSum(sum_, f, &c_);
371  } else {
372  sum_ = internal::KahanLogDiff(sum_, f, &c_);
373  }
374  } else {
375  if (equal) {
376  sum_ = internal::KahanLogSum(f, sum_, &c_);
377  } else {
378  sum_ = internal::KahanLogDiff(f, sum_, &c_);
379  ssum_ = sw;
380  }
381  }
382  return Sum();
383  }
384 
385  Weight Sum() const { return Weight(W1(ssum_ ? 1.0 : -1.0), W2(sum_)); }
386 
387  void Reset(Weight w = Weight::Zero()) {
388  ssum_ = w.IsPositive();
389  sum_ = w.Value2().Value();
390  c_ = 0.0;
391  }
392 
393  private:
394  bool ssum_; // true iff sign of sum is positive
395  double sum_; // unsigned sum
396  double c_; // Kahan compensation
397 };
398 
399 // Converts to tropical.
400 template <>
403  if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
404  return TropicalWeight::NoWeight();
405  }
406  return TropicalWeight(weight.Value2().Value());
407  }
408 };
409 
410 template <>
413  if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
414  return TropicalWeight::NoWeight();
415  }
416  return TropicalWeight(weight.Value2().Value());
417  }
418 };
419 
420 // Converts to log.
421 template <>
423  LogWeight operator()(const SignedLogWeight &weight) const {
424  if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
425  return LogWeight::NoWeight();
426  }
427  return LogWeight(weight.Value2().Value());
428  }
429 };
430 
431 template <>
433  LogWeight operator()(const SignedLog64Weight &weight) const {
434  if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
435  return LogWeight::NoWeight();
436  }
437  return LogWeight(weight.Value2().Value());
438  }
439 };
440 
441 // Converts to log64.
442 template <>
444  Log64Weight operator()(const SignedLogWeight &weight) const {
445  if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
446  return Log64Weight::NoWeight();
447  }
448  return Log64Weight(weight.Value2().Value());
449  }
450 };
451 
452 template <>
454  Log64Weight operator()(const SignedLog64Weight &weight) const {
455  if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
456  return Log64Weight::NoWeight();
457  }
458  return Log64Weight(weight.Value2().Value());
459  }
460 };
461 
462 // Converts to real.
463 template <>
465  RealWeight operator()(const SignedLogWeight &weight) const {
466  return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
467  }
468 };
469 
470 template <>
472  RealWeight operator()(const SignedLog64Weight &weight) const {
473  return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
474  }
475 };
476 
477 // Converts to real64.
478 template <>
480  Real64Weight operator()(const SignedLogWeight &weight) const {
481  return Real64Weight(weight.Value1().Value() *
482  exp(-weight.Value2().Value()));
483  }
484 };
485 
486 template <>
489  return Real64Weight(weight.Value1().Value() *
490  exp(-weight.Value2().Value()));
491  }
492 };
493 
494 // Converts to signed log.
495 template <>
498  return SignedLogWeight(1.0, weight.Value());
499  }
500 };
501 
502 template <>
504  SignedLogWeight operator()(const LogWeight &weight) const {
505  return SignedLogWeight(1.0, weight.Value());
506  }
507 };
508 
509 template <>
511  SignedLogWeight operator()(const Log64Weight &weight) const {
512  return SignedLogWeight(1.0, weight.Value());
513  }
514 };
515 
516 template <>
518  SignedLogWeight operator()(const RealWeight &weight) const {
519  return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
520  -log(std::abs(weight.Value())));
521  }
522 };
523 
524 template <>
526  SignedLogWeight operator()(const Real64Weight &weight) const {
527  return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
528  -log(std::abs(weight.Value())));
529  }
530 };
531 
532 template <>
535  return SignedLogWeight(weight.Value1(), weight.Value2().Value());
536  }
537 };
538 
539 // Converts to signed log64.
540 template <>
543  return SignedLog64Weight(1.0, weight.Value());
544  }
545 };
546 
547 template <>
549  SignedLog64Weight operator()(const LogWeight &weight) const {
550  return SignedLog64Weight(1.0, weight.Value());
551  }
552 };
553 
554 template <>
556  SignedLog64Weight operator()(const Log64Weight &weight) const {
557  return SignedLog64Weight(1.0, weight.Value());
558  }
559 };
560 
561 template <>
563  SignedLog64Weight operator()(const RealWeight &weight) const {
564  return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
565  -log(std::abs(weight.Value())));
566  }
567 };
568 
569 template <>
572  return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
573  -log(std::abs(weight.Value())));
574  }
575 };
576 
577 template <>
580  return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
581  }
582 };
583 
// This function object returns SignedLogWeightTpl<T>'s that are random integers
// chosen from [0, num_random_weights) times a random sign; when allow_zero is
// true it may also return Zero() with a random sign. This is intended
// primarily for testing.
587 template <class T>
589  public:
591  using W1 = typename Weight::W1;
592  using W2 = typename Weight::W2;
593 
594  explicit WeightGenerate(uint64_t seed = std::random_device()(),
595  bool allow_zero = true,
596  size_t num_random_weights = kNumRandomWeights)
597  : rand_(seed),
598  allow_zero_(allow_zero),
599  num_random_weights_(num_random_weights) {}
600 
601  Weight operator()() const {
602  static constexpr W1 negative(-1.0);
603  static constexpr W1 positive(+1.0);
604  const bool sign = std::bernoulli_distribution(.5)(rand_);
605  const int sample = std::uniform_int_distribution<>(
606  0, num_random_weights_ + allow_zero_ - 1)(rand_);
607  if (allow_zero_ && sample == num_random_weights_) {
608  return Weight(sign ? positive : negative, W2::Zero());
609  }
610  return Weight(sign ? positive : negative, W2(sample));
611  }
612 
613  private:
614  mutable std::mt19937_64 rand_;
615  const bool allow_zero_;
616  const size_t num_random_weights_;
617 };
618 
619 } // namespace fst
620 
621 #endif // FST_SIGNED_LOG_WEIGHT_H_
static const SignedLogWeightTpl & NoWeight()
Real64Weight operator()(const SignedLogWeight &weight) const
TropicalWeight operator()(const SignedLog64Weight &weight) const
Real64Weight operator()(const SignedLog64Weight &weight) const
RealWeightTpl< float > RealWeight
Definition: float-weight.h:711
static const SignedLogWeightTpl & Zero()
LogWeight operator()(const SignedLogWeight &weight) const
SignedLog64Weight operator()(const RealWeight &weight) const
RealWeight operator()(const SignedLogWeight &weight) const
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
SignedLogWeightTpl(const PairWeight< W1, W2 > &weight)
void Reset(Weight w=Weight::Zero())
SignedLog64Weight operator()(const TropicalWeight &weight) const
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
bool Member() const
Definition: pair-weight.h:75
SignedLogWeight operator()(const TropicalWeight &weight) const
const LogWeightTpl< T > & Value2() const
Definition: pair-weight.h:95
static constexpr uint64_t Properties()
static constexpr LogWeightTpl Zero()
Definition: float-weight.h:436
SignedLogWeightTpl Quantize(float delta=kDelta) const
SignedLogWeight operator()(const Log64Weight &weight) const
static constexpr T PosInfinity()
Definition: float-weight.h:61
constexpr uint64_t kRightSemiring
Definition: weight.h:139
static constexpr LogWeightTpl One()
Definition: float-weight.h:438
#define FSTERROR()
Definition: util.h:56
LogWeightTpl< double > Log64Weight
Definition: float-weight.h:472
Log64Weight operator()(const SignedLog64Weight &weight) const
LogWeightTpl< float > LogWeight
Definition: float-weight.h:469
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true, size_t num_random_weights=kNumRandomWeights)
TropicalWeightTpl< float > TropicalWeight
Definition: float-weight.h:268
bool operator!=(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:54
static const SignedLogWeightTpl & One()
static constexpr LogWeightTpl NoWeight()
Definition: float-weight.h:440
constexpr uint64_t kCommutative
Definition: weight.h:144
SignedLogWeightTpl(const W1 &w1, const W2 &w2)
SignedLog64Weight operator()(const SignedLogWeight &weight) const
static constexpr TropicalWeightTpl< T > NoWeight()
Definition: float-weight.h:221
double LogPosExp(double x)
Definition: float-weight.h:477
double KahanLogSum(double a, double b, double *c)
Definition: float-weight.h:492
RealWeight operator()(const SignedLog64Weight &weight) const
SignedLogWeightTpl< double > SignedLog64Weight
static const std::string & Type()
Definition: float-weight.h:225
SignedLog64Weight operator()(const Real64Weight &weight) const
SignedLogWeight operator()(const RealWeight &weight) const
double LogNegExp(double x)
Definition: float-weight.h:483
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:543
SignedLogWeightTpl(const W2 &w2)
static const std::string & Type()
Definition: float-weight.h:442
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
bool operator==(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:51
constexpr size_t kNumRandomWeights
Definition: weight.h:154
Log64Weight operator()(const SignedLogWeight &weight) const
bool SignedLogConvertCheck(W1 weight)
RealWeightTpl< double > Real64Weight
Definition: float-weight.h:714
SignedLog64Weight operator()(const Log64Weight &weight) const
SignedLogWeight operator()(const SignedLog64Weight &weight) const
ReverseWeight Reverse() const
SignedLogWeight operator()(const Real64Weight &weight) const
SignedLog64Weight operator()(const LogWeight &weight) const
SignedLogWeight operator()(const LogWeight &weight) const
DivideType
Definition: weight.h:165
LogWeight operator()(const SignedLog64Weight &weight) const
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
SignedLogWeightTpl< float > SignedLogWeight
static const std::string & Type()
constexpr const T & Value() const
Definition: float-weight.h:96
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
Definition: error-weight.h:58
TropicalWeight operator()(const SignedLogWeight &weight) const
double KahanLogDiff(double a, double b, double *c)
Definition: float-weight.h:504