FST  openfst-1.8.4
OpenFst Library
expectation-weight.h
Go to the documentation of this file.
1 // Copyright 2005-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Expectation semiring, as described in:
19 //
20 // Eisner, J. 2002. Parameter estimation for probabilistic finite-state
21 // transducers. In Proceedings of the 40th Annual Meeting of the
22 // Association for Computational Linguistics, pages 1-8.
23 //
24 // Multiplex semiring operations and identities:
25 //
26 // One: <One, Zero>
27 // Zero: <Zero, Zero>
28 // Plus: <a1, b1> + <a2, b2> = <(a1 + a2), (b1 + b2)>
29 // Times: <a1, b1> * <a2, b2> = <(a1 * a2), [(a1 * b2) + (b1 * a2)]>
30 // Division (see Divide() for proof):
31 // For a left-semiring:
32 // <a1, b1> / <a2, b2> = <a1 / a2, (b1 - b2 * (a1 / a2)) / a2>
33 // For a right-semiring:
34 // <a1, b1> / <a2, b2> = <a1 / a2, (b1 - (a1 / a2) * b2) / a2>
35 //
36 // It is commonly used to store a probability, random variable pair so that
37 // the shortest distance gives the posterior probability and the associated
38 // expected value.
39 
40 #ifndef FST_EXPECTATION_WEIGHT_H_
41 #define FST_EXPECTATION_WEIGHT_H_
42 
43 #include <cstdint>
44 #include <random>
45 #include <string>
46 
47 #include <fst/log.h>
48 #include <fst/pair-weight.h>
49 #include <fst/weight.h>
50 
51 namespace fst {
52 
53 // W1 is usually a probability weight like LogWeight.
54 // W2 is usually a random variable or vector (see SignedLogWeight or
55 // SparsePowerWeight).
56 //
57 // If W1 is distinct from W2, it is required that there is an external product
58 // between W1 and W2 (that is, both Times(W1, W2) -> W2 and Times(W2, W1) -> W2
59 // must be defined) and if both semirings are commutative, or left or right
60 // semirings, then the result must have those properties.
61 template <class W1, class W2>
62 class ExpectationWeight : public PairWeight<W1, W2> {
63  public:
66 
70 
71  using ReverseWeight =
73 
74  ExpectationWeight() : PairWeight<W1, W2>(Zero()) {}
75 
76  explicit ExpectationWeight(const PairWeight<W1, W2> &weight)
77  : PairWeight<W1, W2>(weight) {}
78 
79  ExpectationWeight(const W1 &w1, const W2 &w2) : PairWeight<W1, W2>(w1, w2) {}
80 
81  static const ExpectationWeight &Zero() {
82  static const ExpectationWeight zero(W1::Zero(), W2::Zero());
83  return zero;
84  }
85 
86  static const ExpectationWeight &One() {
87  static const ExpectationWeight one(W1::One(), W2::Zero());
88  return one;
89  }
90 
91  static const ExpectationWeight &NoWeight() {
92  static const ExpectationWeight no_weight(W1::NoWeight(), W2::NoWeight());
93  return no_weight;
94  }
95 
96  static const std::string &Type() {
97  static const std::string *const type =
98  new std::string("expectation_" + W1::Type() + "_" + W2::Type());
99  return *type;
100  }
101 
102  ExpectationWeight Quantize(float delta = kDelta) const {
104  }
105 
108  }
109 
110  bool Member() const { return PairWeight<W1, W2>::Member(); }
111 
112  static constexpr uint64_t Properties() {
113  return W1::Properties() & W2::Properties() &
115  }
116 };
117 
118 template <class W1, class W2>
120  const ExpectationWeight<W1, W2> &w2) {
121  return ExpectationWeight<W1, W2>(Plus(w1.Value1(), w2.Value1()),
122  Plus(w1.Value2(), w2.Value2()));
123 }
124 
125 template <class W1, class W2>
127  const ExpectationWeight<W1, W2> &w2) {
129  Times(w1.Value1(), w2.Value1()),
130  Plus(Times(w1.Value1(), w2.Value2()), Times(w1.Value2(), w2.Value1())));
131 }
132 
133 // Requires
134 // * Divide(W1, W1) -> W1
135 // * Divide(W2, W1) -> W2
136 // * Times(W1, W2) -> W2
137 // (already required by Times(ExpectationWeight, ExpectationWeight).)
138 // * Minus(W2, W2) -> W2
139 // (not part of the Weight interface, so Divide will not compile if
140 // Minus is not defined).
141 template <class W1, class W2>
143  const ExpectationWeight<W1, W2> &w2,
144  DivideType typ) {
145  // No special cases are required for !w1.Member(), !w2.Member(), or
146  // w2 == Zero(), since Minus and Divide will already return NoWeight()
147  // in these cases.
148 
149  // For a right-semiring, by the definition of Divide, we are looking for
150  // z = x / y such that (x / y) * y = x.
151  // Let <x1, x2> = x, <y1, y2> = y, <z1, z2> = z.
152  // <z1, z2> * <y1, y2> = <x1, x2>.
153  // By the definition of Times:
154  // z1 * y1 = x1 and
155  // z1 * y2 + z2 * y1 = x2.
156  // So z1 = x1 / y1, and
157 // z2 * y1 = x2 - z1 * y2
158 // z2 = (x2 - z1 * y2) / y1.
159  // The left-semiring case is symmetric. The commutative case allows
160  // additional simplification to
161  // z2 = z1 * (x2 / x1 - y2 / y1) if x1 != 0
162  // z2 = x2 / y1 if x1 == 0, but this requires testing against 0
163  // with ApproxEquals. We just use the right-semiring result in
164  // this case.
165  const auto w11 = w1.Value1();
166  const auto w12 = w1.Value2();
167  const auto w21 = w2.Value1();
168  const auto w22 = w2.Value2();
169  const W1 q1 = Divide(w11, w21, typ);
170  if (typ == DIVIDE_LEFT) {
171  const W2 q2 = Divide(Minus(w12, Times(w22, q1)), w21, typ);
172  return ExpectationWeight<W1, W2>(q1, q2);
173  } else {
174  // Right or commutative semiring.
175  const W2 q2 = Divide(Minus(w12, Times(q1, w22)), w21, typ);
176  return ExpectationWeight<W1, W2>(q1, q2);
177  }
178 }
179 
180 // Specialization for expectation weight.
181 template <class W1, class W2>
182 class Adder<ExpectationWeight<W1, W2>> {
183  public:
185 
186  Adder() = default;
187 
188  explicit Adder(Weight w) : adder1_(w.Value1()), adder2_(w.Value2()) {}
189 
190  Weight Add(const Weight &w) {
191  adder1_.Add(w.Value1());
192  adder2_.Add(w.Value2());
193  return Sum();
194  }
195 
196  Weight Sum() const { return Weight(adder1_.Sum(), adder2_.Sum()); }
197 
198  void Reset(Weight w = Weight::Zero()) {
199  adder1_.Reset(w.Value1());
200  adder2_.Reset(w.Value2());
201  }
202 
203  private:
204  Adder<W1> adder1_;
205  Adder<W2> adder2_;
206 };
207 
208 // This function object generates weights by calling the underlying generators
209 // for the template weight types, like all other pair weight types. This is
210 // intended primarily for testing.
211 template <class W1, class W2>
213  public:
216 
217  explicit WeightGenerate(uint64_t seed = std::random_device()(),
218  bool allow_zero = true)
219  : generate_(seed, allow_zero) {}
220 
221  Weight operator()() const { return Weight(generate_()); }
222 
223  private:
224  const Generate generate_;
225 };
226 
227 } // namespace fst
228 
229 #endif // FST_EXPECTATION_WEIGHT_H_
static const ExpectationWeight & One()
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:61
static const ExpectationWeight & Zero()
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:64
bool Member() const
Definition: pair-weight.h:75
const W2 & Value2() const
Definition: pair-weight.h:95
constexpr uint64_t kIdempotent
Definition: weight.h:147
ExpectationWeight< typename W1::ReverseWeight, typename W2::ReverseWeight > ReverseWeight
constexpr uint64_t kRightSemiring
Definition: weight.h:139
ExpectationWeight Quantize(float delta=kDelta) const
WeightGenerate(uint64_t seed=std::random_device()(), bool allow_zero=true)
constexpr uint64_t kCommutative
Definition: weight.h:144
ExpectationWeight(const W1 &w1, const W2 &w2)
static constexpr uint64_t Properties()
LogWeightTpl< T > Minus(const LogWeightTpl< T > &w1, const LogWeightTpl< T > &w2)
Definition: float-weight.h:542
ErrorWeight Divide(const ErrorWeight &, const ErrorWeight &)
Definition: error-weight.h:67
ExpectationWeight(const PairWeight< W1, W2 > &weight)
ReverseWeight Reverse() const
DivideType
Definition: weight.h:165
static const std::string & Type()
constexpr uint64_t kLeftSemiring
Definition: weight.h:136
constexpr float kDelta
Definition: weight.h:133
const W1 & Value1() const
Definition: pair-weight.h:93
static const ExpectationWeight & NoWeight()