6 #ifndef FST_TEST_ALGO_TEST_H_ 7 #define FST_TEST_ALGO_TEST_H_ 26 return A(0, 0, arc.weight, arc.nextstate);
56 std::vector<StdArc::StateId> order;
72 template <
class Arc,
class WeightGenerator>
75 typedef typename Arc::Label
Label;
80 const Fst<Arc> &univ_fst, WeightGenerator *weight_generator)
85 weight_generator_(weight_generator) {}
88 TestRational(T1, T2, T3);
90 TestCompose(T1, T2, T3);
101 VLOG(1) <<
"Check destructive and delayed union are equivalent.";
105 CHECK(Equiv(U1, U2));
109 VLOG(1) <<
"Check destructive and delayed concatenation are equivalent.";
113 CHECK(Equiv(C1, C2));
116 CHECK(Equiv(C3, C2));
120 VLOG(1) <<
"Check destructive and delayed closure* are equivalent.";
124 CHECK(Equiv(C1, C2));
128 VLOG(1) <<
"Check destructive and delayed closure+ are equivalent.";
132 CHECK(Equiv(C1, C2));
136 VLOG(1) <<
"Check union is associative (destructive).";
146 CHECK(Equiv(U1, U4));
150 VLOG(1) <<
"Check union is associative (delayed).";
157 CHECK(Equiv(U2, U4));
161 VLOG(1) <<
"Check union is associative (destructive delayed).";
168 CHECK(Equiv(U1, U4));
172 VLOG(1) <<
"Check concatenation is associative (destructive).";
182 CHECK(Equiv(C1, C4));
186 VLOG(1) <<
"Check concatenation is associative (delayed).";
193 CHECK(Equiv(C2, C4));
197 VLOG(1) <<
"Check concatenation is associative (destructive delayed).";
204 CHECK(Equiv(C1, C4));
208 VLOG(1) <<
"Check concatenation left distributes" 209 <<
" over union (destructive).";
223 CHECK(Equiv(C1, U2));
227 VLOG(1) <<
"Check concatenation right distributes" 228 <<
" over union (destructive).";
241 CHECK(Equiv(C1, U2));
245 VLOG(1) <<
"Check concatenation left distributes over union (delayed).";
253 CHECK(Equiv(C1, U2));
257 VLOG(1) <<
"Check concatenation right distributes over union (delayed).";
265 CHECK(Equiv(C1, U2));
269 VLOG(1) <<
"Check T T* == T+ (destructive).";
282 VLOG(1) <<
"Check T* T == T+ (destructive).";
295 VLOG(1) <<
"Check T T* == T+ (delayed).";
305 VLOG(1) <<
"Check T* T == T+ (delayed).";
318 VLOG(1) <<
"Check destructive and delayed projection are equivalent.";
322 CHECK(Equiv(P1, P2));
326 VLOG(1) <<
"Check destructive and delayed inversion are equivalent.";
330 CHECK(Equiv(I1, I2));
334 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (destructive).";
340 CHECK(Equiv(P1, I1));
344 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (destructive).";
350 CHECK(Equiv(P1, I1));
354 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (delayed).";
358 CHECK(Equiv(P1, P2));
362 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (delayed).";
366 CHECK(Equiv(P1, P2));
370 VLOG(1) <<
"Check destructive relabeling";
371 static const int kNumLabels = 10;
373 std::vector<Label> labelset(kNumLabels);
374 for (
size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
375 for (
size_t i = 0; i < kNumLabels; ++i) {
377 swap(labelset[i], labelset[rand() % kNumLabels]);
380 std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
381 std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
382 for (
size_t i = 0; i < kNumLabels; ++i) {
383 ipairs1[i] = std::make_pair(i, labelset[i]);
384 opairs1[i] = std::make_pair(labelset[i], i);
389 std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
390 std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
391 for (
size_t i = 0; i < kNumLabels; ++i) {
392 ipairs2[i] = std::make_pair(labelset[i], i);
393 opairs2[i] = std::make_pair(i, labelset[i]);
398 VLOG(1) <<
"Check on-the-fly relabeling";
402 CHECK(Equiv(RRdelay, T));
406 VLOG(1) <<
"Check encoding/decoding (destructive).";
409 if (rand() % 2) encode_props |= kEncodeLabels;
410 if (rand() % 2) encode_props |= kEncodeWeights;
418 VLOG(1) <<
"Check encoding/decoding (delayed).";
420 if (rand() % 2) encode_props |= kEncodeLabels;
421 if (rand() % 2) encode_props |= kEncodeWeights;
430 VLOG(1) <<
"Check gallic mappers (constructive).";
436 ArcMap(G, &F, from_mapper);
441 VLOG(1) <<
"Check gallic mappers (delayed).";
466 VLOG(1) <<
"Check composition is associative.";
472 CHECK(Equiv(C2, C4));
476 VLOG(1) <<
"Check composition left distributes over union.";
484 CHECK(Equiv(C1, U2));
488 VLOG(1) <<
"Check composition right distributes over union.";
496 CHECK(Equiv(C1, U2));
507 VLOG(1) <<
"Check intersection is commutative.";
510 CHECK(Equiv(I1, I2));
514 VLOG(1) <<
"Check all epsilon filters leads to equivalent results.";
522 CHECK(Equiv(C1, C2));
523 CHECK(Equiv(C1, C3));
530 CHECK(Equiv(C1, C4));
537 CHECK(Equiv(C1, C5));
542 VLOG(1) <<
"Check look-ahead filters lead to equivalent results.";
546 CHECK(Equiv(C1, C2));
556 VLOG(1) <<
"Check arc sorted Fst is equivalent to its input.";
563 VLOG(1) <<
"Check destructive and delayed arcsort are equivalent.";
567 CHECK(Equiv(S1, S2));
571 VLOG(1) <<
"Check ilabel sorting vs. olabel sorting with inversions.";
578 CHECK(Equiv(S1, S2));
582 VLOG(1) <<
"Check topologically sorted Fst is equivalent to its input.";
589 VLOG(1) <<
"Check reverse(reverse(T)) = T";
590 for (
int i = 0; i < 2; ++i) {
593 bool require_superinitial = i == 1;
594 Reverse(T, &R1, require_superinitial);
595 Reverse(R1, &R2, require_superinitial);
602 void TestOptimize(
const Fst<Arc> &T) {
604 uint64 wprops = Weight::Properties();
609 VLOG(1) <<
"Check connected FST is equivalent to its input.";
617 VLOG(1) <<
"Check epsilon-removed FST is equivalent to its input.";
622 VLOG(1) <<
"Check destructive and delayed epsilon removal" 623 <<
"are equivalent.";
625 CHECK(Equiv(R1, R2));
627 VLOG(1) <<
"Check an FST with a large proportion" 628 <<
" of epsilon transitions:";
635 Arc arc(1, 1, Weight::One(), V.
AddState());
637 V.
SetFinal(arc.nextstate, Weight::One());
641 std::vector<Weight> d;
643 Weight w = U.
Start() < d.size() ? d[U.
Start()] : Weight::Zero();
647 Weight w1 = U1.
Start() < d.size() ? d[U1.
Start()] : Weight::Zero();
651 Weight w2 = U2.
Start() < d.size() ? d[U2.
Start()] : Weight::Zero();
655 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
656 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
661 VLOG(1) <<
"Check determinized FST is equivalent to its input.";
669 VLOG(1) <<
"Check pruning in determinization";
671 Weight threshold = (*weight_generator_)();
676 CHECK(PruneEquiv(A, P, threshold));
679 if ((wprops &
kPath) == kPath) {
680 VLOG(1) <<
"Check min-determinization";
684 std::vector<std::pair<Label, Label>> ipairs, opairs;
685 ipairs.push_back(std::pair<Label, Label>(0, 1));
693 CHECK(MinRelated(M, R));
698 VLOG(1) <<
"Check size(min(det(A))) <= size(det(A))" 699 <<
" and min(det(A)) equiv det(A)";
708 if (n && (wprops & kIdempotent) == kIdempotent &&
710 VLOG(1) <<
"Check that Revuz's algorithm leads to the" 711 <<
" same number of states as Brozozowski's algorithm";
729 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
730 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
735 VLOG(1) <<
"Check disambiguated FSA is unambiguous";
736 CHECK(Unambiguous(D));
753 VLOG(1) <<
"Check reweight(T) equiv T";
754 std::vector<Weight> potential;
757 while (potential.size() < RI.
NumStates())
758 potential.push_back((*weight_generator_)());
767 if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
768 VLOG(1) <<
"Check pushed FST is equivalent to input FST.";
787 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1,
kPushLabels);
800 VLOG(1) <<
"Check pruning algorithm";
802 VLOG(1) <<
"Check equiv. of constructive and destructive algorithms";
803 Weight thresold = (*weight_generator_)();
805 Prune(&P1, thresold);
807 Prune(T, &P2, thresold);
808 CHECK(Equiv(P1, P2));
812 VLOG(1) <<
"Check prune(reverse) equiv reverse(prune)";
813 Weight thresold = (*weight_generator_)();
817 Prune(&P1, thresold);
819 Prune(&R, thresold.Reverse());
821 CHECK(Equiv(P1, P2));
824 VLOG(1) <<
"Check: ShortestDistance(A - prune(A))" 825 <<
" > ShortestDistance(A) times Threshold";
826 Weight threshold = (*weight_generator_)();
828 Prune(A, &P, threshold);
829 CHECK(PruneEquiv(A, P, threshold));
832 if (tprops & kAcyclic) {
833 VLOG(1) <<
"Check synchronize(T) equiv T";
840 void TestSearch(
const Fst<Arc> &T) {
841 uint64 wprops = Weight::Properties();
847 VLOG(1) <<
"Check 1-best weight.";
856 VLOG(1) <<
"Check n-best weights";
860 int nshortest = rand() % kNumRandomShortestPaths + 2;
863 false, Weight::Zero(), kNumShortestStates,
865 std::vector<Weight> distance;
867 StateId pstart = paths.
Start();
870 for (; !piter.
Done(); piter.
Next()) {
871 StateId s = piter.
Value().nextstate;
872 Weight nsum = s < distance.size()
893 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
901 return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts);
912 return Equiv(lfst1, lfst2);
919 bool MinRelated(
const Fst<A> &fst1,
const Fst<A> &fst2) {
924 if (!Equiv(P1, P2)) {
925 LOG(ERROR) <<
"Inputs not equivalent";
935 for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
941 Compose(paths1, path, &paths2);
944 LOG(ERROR) <<
"Sums not equivalent: " << sum1 <<
" " << sum2;
954 bool PruneEquiv(
const Fst<A> &fst,
const Fst<A> &pfst, Weight threshold) {
955 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
976 WeightGenerator *weight_generator_;
978 static const int kRandomPathLength;
980 static const int kNumRandomPaths;
982 static const int kNumRandomShortestPaths;
984 static const int kNumShortestStates;
986 static const float kTestDelta;
992 template <
class A,
class WG>
995 template <
class A,
class WG>
998 template <
class A,
class WG>
1001 template <
class A,
class WG>
1004 template <
class A,
class WG>
1010 template <
class Arc>
1032 : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {}
1035 TestRational(A1, A2, A3);
1036 TestIntersect(A1, A2, A3);
1045 VLOG(1) <<
"Check the union contains its arguments (destructive).";
1049 CHECK(Subset(A1, U));
1050 CHECK(Subset(A2, U));
1054 VLOG(1) <<
"Check the union contains its arguments (delayed).";
1057 CHECK(Subset(A1, U));
1058 CHECK(Subset(A2, U));
1062 VLOG(1) <<
"Check if A^n c A* (destructive).";
1065 for (
int i = 0; i < n; ++i)
Concat(&C, A1);
1069 CHECK(Subset(C, S));
1073 VLOG(1) <<
"Check if A^n c A* (delayed).";
1076 for (
int i = 0; i < n; ++i) {
1082 CHECK(Subset(*C, S));
1101 VLOG(1) <<
"Check the intersection is contained in its arguments.";
1103 CHECK(Subset(I1, S1));
1104 CHECK(Subset(I1, S2));
1108 VLOG(1) <<
"Check union distributes over intersection.";
1117 CHECK(Equiv(U1, I2));
1122 Complement(S1, &C1);
1123 Complement(S2, &C2);
1128 VLOG(1) <<
"Check S U S' = Sigma*";
1130 CHECK(Equiv(U, univ_fsa_));
1134 VLOG(1) <<
"Check S n S' = {}";
1136 CHECK(Equiv(I, zero_fsa_));
1140 VLOG(1) <<
"Check (S1' U S2') == (S1 n S2)'";
1146 CHECK(Equiv(U, C3));
1150 VLOG(1) <<
"Check (S1' n S2') == (S1 U S2)'";
1156 CHECK(Equiv(I, C3));
1161 void TestOptimize(
const Fst<Arc> &A) {
1163 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
1169 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
1178 VLOG(1) <<
"Check minimized FSA is equivalent to its input.";
1190 VLOG(1) <<
"Check that Hopcroft's and Revuz's algorithms lead to the" 1191 <<
" same number of states as Brozozowski's algorithm";
1208 VLOG(1) <<
"Check FSAs for sanity (including property bits).";
1233 Union(&ufsa, dfsa2);
1238 CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1245 VLOG(1) <<
"Check FSAs (incl. property bits) for sanity";
1284 template <
class Arc,
class WeightGenerator>
1292 : weight_generator_(generator) {
1293 one_fst_.AddState();
1294 one_fst_.SetStart(0);
1295 one_fst_.SetFinal(0, Weight::One());
1297 univ_fst_.AddState();
1298 univ_fst_.SetStart(0);
1299 univ_fst_.SetFinal(0, Weight::One());
1300 for (
int i = 0; i < kNumRandomLabels; ++i)
1301 univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0));
1304 seed, zero_fst_, one_fst_, univ_fst_, &weight_generator_);
1306 unweighted_tester_ =
1311 delete weighted_tester_;
1312 delete unweighted_tester_;
1316 RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1317 kNumRandomLabels, kAcyclicProb,
1318 &weight_generator_, fst);
1322 VLOG(1) <<
"weight type = " << Weight::Type();
1324 for (
int i = 0; i < FLAGS_repeat; ++i) {
1332 weighted_tester_->Test(T1, T2, T3);
1340 ArcMap(&A1, rm_weight_mapper_);
1341 ArcMap(&A2, rm_weight_mapper_);
1342 ArcMap(&A3, rm_weight_mapper_);
1343 unweighted_tester_->Test(A1, A2, A3);
1349 WeightGenerator weight_generator_;
1370 static const int kNumRandomStates;
1373 static const int kNumRandomArcs;
1376 static const int kNumRandomLabels;
1379 static const float kAcyclicProb;
1382 static const int kRandomPathLength;
1385 static const int kNumRandomPaths;
1391 template <
class A,
class G>
1394 template <
class A,
class G>
1397 template <
class A,
class G>
1400 template <
class A,
class G>
1403 template <
class A,
class G>
1406 template <
class A,
class G>
1411 #endif // FST_TEST_ALGO_TEST_H_
constexpr uint64 kNoEpsilons
void ArcMap(MutableFst< A > *fst, C *mapper)
void AddArc(StateId s, const Arc &arc) override
void SetStart(StateId s) override
void RmEpsilon(MutableFst< Arc > *fst, std::vector< typename Arc::Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
void Invert(const Fst< Arc > &ifst, MutableFst< Arc > *ofst)
constexpr uint64 kNotOLabelSorted
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
void SetFinal(StateId s, Weight weight) override
void Closure(MutableFst< Arc > *fst, ClosureType closure_type)
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
constexpr uint64 kRightSemiring
void Determinize(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DeterminizeOptions< Arc > &opts=DeterminizeOptions< Arc >())
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32 num_paths, float delta, const RandGenOptions< ArcSelector > &opts, bool *error=nullptr)
constexpr uint64 kILabelSorted
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
StateId Start() const override
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa)
ExpectationWeight< X1, X2 > Times(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
bool TopSort(MutableFst< Arc > *fst)
constexpr uint64 kCommutative
void Connect(MutableFst< Arc > *fst)
constexpr uint64 kNotILabelSorted
constexpr uint64 kLeftSemiring
constexpr uint64 kFstProperties
const Arc & Value() const
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
virtual uint64 Properties(uint64 mask, bool test) const =0
void Map(MutableFst< A > *fst, C *mapper)
constexpr uint64 kEpsilons
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
void MakeRandFst(MutableFst< Arc > *fst)
StateId AddState() override
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
MapSymbolsAction OutputSymbolsAction() const
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
AlgoTester(WeightGenerator generator, int seed)
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64 kSemiring
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
constexpr uint64 kOLabelSorted
constexpr uint64 kIEpsilons
constexpr uint64 kIdempotent
StateId NumStates() const override
MapFinalAction FinalAction() const
void LookAheadCompose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst)
constexpr uint64 kIDeterministic
A operator()(const A &arc) const
ExpectationWeight< X1, X2 > Plus(const ExpectationWeight< X1, X2 > &w1, const ExpectationWeight< X1, X2 > &w2)
WeightedTester(time_t seed, const Fst< Arc > &zero_fst, const Fst< Arc > &one_fst, const Fst< Arc > &univ_fst, WeightGenerator *weight_generator)
constexpr uint64 kAcceptor
void Test(const Fst< Arc > &T1, const Fst< Arc > &T2, const Fst< Arc > &T3)
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
MapSymbolsAction InputSymbolsAction() const
constexpr uint64 kNoOEpsilons
void Concat(MutableFst< Arc > *fst1, const Fst< Arc > &fst2)
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
bool Verify(const Fst< Arc > &fst, bool allow_negative_labels=false)
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa)
void Minimize(MutableFst< Arc > *fst, MutableFst< Arc > *sfst=nullptr, float delta=kShortestDelta, bool allow_nondet=false)
constexpr uint64 kNotAcceptor
constexpr uint64 kAcyclic
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
static const string & Type()
constexpr bool ApproxEqual(const FloatWeightTpl< T > &w1, const FloatWeightTpl< T > &w2, float delta=kDelta)
bool Equivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, float delta=kDelta, bool *error=nullptr)
constexpr uint32 kPushWeights
uint64 Properties(uint64 props) const
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
uint64 Properties(uint64 mask, bool test) const override
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
constexpr uint64 kOEpsilons
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
constexpr uint32 kPushLabels
constexpr uint64 kNoIEpsilons