20 #ifndef FST_TEST_ALGO_TEST_H_ 21 #define FST_TEST_ALGO_TEST_H_ 45 return A(0, 0, arc.weight, arc.nextstate);
75 std::vector<StdArc::StateId> order;
94 using Label =
typename Arc::Label;
107 generate_(std::move(weight_generator)) {}
110 TestRational(T1, T2, T3);
112 TestCompose(T1, T2, T3);
123 VLOG(1) <<
"Check destructive and delayed union are equivalent.";
127 CHECK(Equiv(U1, U2));
131 VLOG(1) <<
"Check destructive and delayed concatenation are equivalent.";
135 CHECK(Equiv(C1, C2));
138 CHECK(Equiv(C3, C2));
142 VLOG(1) <<
"Check destructive and delayed closure* are equivalent.";
146 CHECK(Equiv(C1, C2));
150 VLOG(1) <<
"Check destructive and delayed closure+ are equivalent.";
154 CHECK(Equiv(C1, C2));
158 VLOG(1) <<
"Check union is associative (destructive).";
168 CHECK(Equiv(U1, U4));
172 VLOG(1) <<
"Check union is associative (delayed).";
179 CHECK(Equiv(U2, U4));
183 VLOG(1) <<
"Check union is associative (destructive delayed).";
190 CHECK(Equiv(U1, U4));
194 VLOG(1) <<
"Check concatenation is associative (destructive).";
204 CHECK(Equiv(C1, C4));
208 VLOG(1) <<
"Check concatenation is associative (delayed).";
215 CHECK(Equiv(C2, C4));
219 VLOG(1) <<
"Check concatenation is associative (destructive delayed).";
226 CHECK(Equiv(C1, C4));
230 VLOG(1) <<
"Check concatenation left distributes" 231 <<
" over union (destructive).";
245 CHECK(Equiv(C1, U2));
249 VLOG(1) <<
"Check concatenation right distributes" 250 <<
" over union (destructive).";
263 CHECK(Equiv(C1, U2));
267 VLOG(1) <<
"Check concatenation left distributes over union (delayed).";
275 CHECK(Equiv(C1, U2));
279 VLOG(1) <<
"Check concatenation right distributes over union (delayed).";
287 CHECK(Equiv(C1, U2));
291 VLOG(1) <<
"Check T T* == T+ (destructive).";
304 VLOG(1) <<
"Check T* T == T+ (destructive).";
317 VLOG(1) <<
"Check T T* == T+ (delayed).";
327 VLOG(1) <<
"Check T* T == T+ (delayed).";
340 VLOG(1) <<
"Check destructive and delayed projection are equivalent.";
344 CHECK(Equiv(P1, P2));
348 VLOG(1) <<
"Check destructive and delayed inversion are equivalent.";
352 CHECK(Equiv(I1, I2));
356 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (destructive).";
362 CHECK(Equiv(P1, I1));
366 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (destructive).";
372 CHECK(Equiv(P1, I1));
376 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (delayed).";
380 CHECK(Equiv(P1, P2));
384 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (delayed).";
388 CHECK(Equiv(P1, P2));
392 VLOG(1) <<
"Check destructive relabeling";
393 static const int kNumLabels = 10;
395 std::vector<Label> labelset(kNumLabels);
396 for (
size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
397 for (
size_t i = 0; i < kNumLabels; ++i) {
400 std::uniform_int_distribution<>(0, kNumLabels - 1)(rand_);
401 swap(labelset[i], labelset[index]);
404 std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
405 std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
406 for (
size_t i = 0; i < kNumLabels; ++i) {
407 ipairs1[i] = std::make_pair(i, labelset[i]);
408 opairs1[i] = std::make_pair(labelset[i], i);
413 std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
414 std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
415 for (
size_t i = 0; i < kNumLabels; ++i) {
416 ipairs2[i] = std::make_pair(labelset[i], i);
417 opairs2[i] = std::make_pair(i, labelset[i]);
422 VLOG(1) <<
"Check on-the-fly relabeling";
426 CHECK(Equiv(RRdelay, T));
430 VLOG(1) <<
"Check encoding/decoding (destructive).";
432 uint8_t encode_props = 0;
433 if (std::bernoulli_distribution(.5)(rand_)) {
436 if (std::bernoulli_distribution(.5)(rand_)) {
446 VLOG(1) <<
"Check encoding/decoding (delayed).";
447 uint8_t encode_props = 0;
448 if (std::bernoulli_distribution(.5)(rand_)) {
451 if (std::bernoulli_distribution(.5)(rand_)) {
462 VLOG(1) <<
"Check gallic mappers (constructive).";
468 ArcMap(G, &F, from_mapper);
473 VLOG(1) <<
"Check gallic mappers (delayed).";
496 VLOG(1) <<
"Check composition is associative.";
502 CHECK(Equiv(C2, C4));
506 VLOG(1) <<
"Check composition left distributes over union.";
514 CHECK(Equiv(C1, U2));
518 VLOG(1) <<
"Check composition right distributes over union.";
526 CHECK(Equiv(C1, U2));
537 VLOG(1) <<
"Check intersection is commutative.";
540 CHECK(Equiv(I1, I2));
544 VLOG(1) <<
"Check all epsilon filters leads to equivalent results.";
552 CHECK(Equiv(C1, C2));
553 CHECK(Equiv(C1, C3));
560 CHECK(Equiv(C1, C4));
563 CHECK(Equiv(C1, C5));
570 CHECK(Equiv(C1, C6));
575 VLOG(1) <<
"Check look-ahead filters lead to equivalent results.";
579 CHECK(Equiv(C1, C2));
589 VLOG(1) <<
"Check arc sorted Fst is equivalent to its input.";
596 VLOG(1) <<
"Check destructive and delayed arcsort are equivalent.";
600 CHECK(Equiv(S1, S2));
604 VLOG(1) <<
"Check ilabel sorting vs. olabel sorting with inversions.";
611 CHECK(Equiv(S1, S2));
615 VLOG(1) <<
"Check topologically sorted Fst is equivalent to its input.";
622 VLOG(1) <<
"Check reverse(reverse(T)) = T";
623 for (
int i = 0; i < 2; ++i) {
626 bool require_superinitial = i == 1;
627 Reverse(T, &R1, require_superinitial);
628 Reverse(R1, &R2, require_superinitial);
635 void TestOptimize(
const Fst<Arc> &T) {
637 uint64_t wprops = Weight::Properties();
642 VLOG(1) <<
"Check connected FST is equivalent to its input.";
650 VLOG(1) <<
"Check epsilon-removed FST is equivalent to its input.";
655 VLOG(1) <<
"Check destructive and delayed epsilon removal" 656 <<
"are equivalent.";
658 CHECK(Equiv(R1, R2));
660 VLOG(1) <<
"Check an FST with a large proportion" 661 <<
" of epsilon transitions:";
668 Arc arc(1, 1, Weight::One(), V.
AddState());
670 V.
SetFinal(arc.nextstate, Weight::One());
674 std::vector<Weight> d;
688 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
689 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
694 VLOG(1) <<
"Check determinized FST is equivalent to its input.";
702 VLOG(1) <<
"Check pruning in determinization";
704 const Weight threshold = generate_();
709 CHECK(PruneEquiv(A, P, threshold));
712 if ((wprops &
kPath) == kPath) {
713 VLOG(1) <<
"Check min-determinization";
717 std::vector<std::pair<Label, Label>> ipairs, opairs;
718 ipairs.push_back(std::pair<Label, Label>(0, 1));
726 CHECK(MinRelated(M, R));
731 VLOG(1) <<
"Check size(min(det(A))) <= size(det(A))" 732 <<
" and min(det(A)) equiv det(A)";
741 if (n && (wprops & kIdempotent) == kIdempotent &&
743 VLOG(1) <<
"Check that Revuz's algorithm leads to the" 744 <<
" same number of states as Brozozowski's algorithm";
762 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
763 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
768 VLOG(1) <<
"Check disambiguated FSA is unambiguous";
769 CHECK(Unambiguous(D));
786 VLOG(1) <<
"Check reweight(T) equiv T";
787 std::vector<Weight> potential;
790 while (potential.size() < RI.
NumStates()) {
791 potential.push_back(generate_());
801 if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
802 VLOG(1) <<
"Check pushed FST is equivalent to input FST.";
821 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1,
kPushLabels);
835 VLOG(1) <<
"Check pruning algorithm";
837 VLOG(1) <<
"Check equiv. of constructive and destructive algorithms";
838 const Weight threshold = generate_();
840 Prune(&P1, threshold);
842 Prune(T, &P2, threshold);
843 CHECK(Equiv(P1, P2));
847 VLOG(1) <<
"Check prune(reverse) equiv reverse(prune)";
848 const Weight threshold = generate_();
852 Prune(&P1, threshold);
854 Prune(&R, threshold.Reverse());
856 CHECK(Equiv(P1, P2));
859 VLOG(1) <<
"Check: ShortestDistance(A - prune(A))" 860 <<
" > ShortestDistance(A) times Threshold";
861 const Weight threshold = generate_();
863 Prune(A, &P, threshold);
864 CHECK(PruneEquiv(A, P, threshold));
868 if (tprops & kAcyclic) {
869 VLOG(1) <<
"Check synchronize(T) equiv T";
876 void TestSearch(
const Fst<Arc> &T) {
878 uint64_t wprops = Weight::Properties();
884 VLOG(1) <<
"Check 1-best weight.";
893 VLOG(1) <<
"Check n-best weights";
897 const int nshortest = std::uniform_int_distribution<>(
898 0, kNumRandomShortestPaths + 1)(rand_);
901 false, Weight::Zero(), kNumShortestStates,
903 std::vector<Weight> distance;
908 for (; !piter.
Done(); piter.
Next()) {
910 Weight nsum = s < distance.size()
932 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
940 return RandEquivalent(fst1, fst2, kNumRandomPaths, opts, kTestDelta, seed_);
951 return Equiv(lfst1, lfst2);
958 bool MinRelated(
const Fst<A> &fst1,
const Fst<A> &fst2) {
963 if (!Equiv(P1, P2)) {
964 LOG(ERROR) <<
"Inputs not equivalent";
974 for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
980 Compose(paths1, path, &paths2);
983 LOG(ERROR) <<
"Sums not equivalent: " << sum1 <<
" " << sum2;
993 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
1007 std::mt19937_64 rand_;
1017 static constexpr
int kRandomPathLength = 25;
1019 static constexpr
int kNumRandomPaths = 100;
1021 static constexpr
int kNumRandomShortestPaths = 100;
1023 static constexpr
int kNumShortestStates = 10000;
1025 static constexpr
float kTestDelta = .05;
1034 template <
class Arc>
1038 const Fst<Arc> &univ_fsa, uint64_t seed) {}
1055 const Fst<Arc> &univ_fsa, uint64_t seed)
1056 : zero_fsa_(zero_fsa),
1058 univ_fsa_(univ_fsa),
1062 TestRational(A1, A2, A3);
1063 TestIntersect(A1, A2, A3);
1072 VLOG(1) <<
"Check the union contains its arguments (destructive).";
1076 CHECK(Subset(A1, U));
1077 CHECK(Subset(A2, U));
1081 VLOG(1) <<
"Check the union contains its arguments (delayed).";
1084 CHECK(Subset(A1, U));
1085 CHECK(Subset(A2, U));
1089 VLOG(1) <<
"Check if A^n c A* (destructive).";
1091 const int n = std::uniform_int_distribution<>(0, 4)(rand_);
1092 for (
int i = 0; i < n; ++i)
Concat(&C, A1);
1096 CHECK(Subset(C, S));
1100 VLOG(1) <<
"Check if A^n c A* (delayed).";
1101 const int n = std::uniform_int_distribution<>(0, 4)(rand_);
1102 std::unique_ptr<Fst<Arc>> C = std::make_unique<VectorFst<Arc>>(one_fsa_);
1103 for (
int i = 0; i < n; ++i) {
1104 C = std::make_unique<ConcatFst<Arc>>(*C, A1);
1107 CHECK(Subset(*C, S));
1125 VLOG(1) <<
"Check the intersection is contained in its arguments.";
1127 CHECK(Subset(I1, S1));
1128 CHECK(Subset(I1, S2));
1132 VLOG(1) <<
"Check union distributes over intersection.";
1141 CHECK(Equiv(U1, I2));
1146 Complement(S1, &C1);
1147 Complement(S2, &C2);
1152 VLOG(1) <<
"Check S U S' = Sigma*";
1154 CHECK(Equiv(U, univ_fsa_));
1158 VLOG(1) <<
"Check S n S' = {}";
1160 CHECK(Equiv(I, zero_fsa_));
1164 VLOG(1) <<
"Check (S1' U S2') == (S1 n S2)'";
1170 CHECK(Equiv(U, C3));
1174 VLOG(1) <<
"Check (S1' n S2') == (S1 U S2)'";
1180 CHECK(Equiv(I, C3));
1185 void TestOptimize(
const Fst<Arc> &A) {
1187 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
1193 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
1202 VLOG(1) <<
"Check minimized FSA is equivalent to its input.";
1214 VLOG(1) <<
"Check that Hopcroft's and Revuz's algorithms lead to the" 1215 <<
" same number of states as Brozozowski's algorithm";
1232 VLOG(1) <<
"Check FSAs for sanity (including property bits).";
1257 Union(&ufsa, dfsa2);
1262 CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1269 VLOG(1) <<
"Check FSAs (incl. property bits) for sanity";
1301 std::mt19937_64 rand_;
1308 template <
class Arc>
1317 : generate_(std::move(generator)), rand_(seed) {
1318 one_fst_.AddState();
1319 one_fst_.SetStart(0);
1320 one_fst_.SetFinal(0);
1322 univ_fst_.AddState();
1323 univ_fst_.SetStart(0);
1324 univ_fst_.SetFinal(0);
1325 for (
int i = 0; i < kNumRandomLabels; ++i) univ_fst_.EmplaceArc(0, i, i, 0);
1328 univ_fst_, generate_));
1330 unweighted_tester_.reset(
1335 RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1336 kNumRandomLabels, kAcyclicProb, generate_,
1341 VLOG(1) <<
"weight type = " << Weight::Type();
1343 for (
int i = 0; i < FST_FLAGS_repeat; ++i) {
1351 weighted_tester_->Test(T1, T2, T3);
1359 ArcMap(&A1, rm_weight_mapper_);
1360 ArcMap(&A2, rm_weight_mapper_);
1361 ArcMap(&A3, rm_weight_mapper_);
1362 unweighted_tester_->Test(A1, A2, A3);
1370 std::mt19937_64 rand_;
1378 std::unique_ptr<WeightedTester<Arc>> weighted_tester_;
1380 std::unique_ptr<UnweightedTester<Arc>> unweighted_tester_;
1384 static constexpr
int kNumRandomStates = 10;
1386 static constexpr
int kNumRandomArcs = 25;
1388 static constexpr
int kNumRandomLabels = 5;
1390 static constexpr
float kAcyclicProb = .25;
1392 static constexpr
int kRandomPathLength = 25;
1394 static constexpr
int kNumRandomPaths = 100;
1401 #endif // FST_TEST_ALGO_TEST_H_
constexpr uint64_t kSemiring
void ArcMap(MutableFst< A > *fst, C *mapper)
void AddArc(StateId s, const Arc &arc) override
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32_t npath, const RandGenOptions< ArcSelector > &opts, float delta=kDelta, uint64_t seed=std::random_device()(), bool *error=nullptr)
void SetStart(StateId s) override
void RmEpsilon(MutableFst< Arc > *fst, std::vector< typename Arc::Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
void Invert(const Fst< Arc > &ifst, MutableFst< Arc > *ofst)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kOEpsilons
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
constexpr uint8_t kPushLabels
void MakeRandFst(MutableFst< Arc > *fst)
void SetFinal(StateId s, Weight weight=Weight::One()) override
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
ArcMapFst(const Fst< typename ArcMapper::FromArc > &, const ArcMapper &) -> ArcMapFst< typename ArcMapper::FromArc, typename ArcMapper::ToArc, ArcMapper >
void Closure(MutableFst< Arc > *fst, ClosureType closure_type)
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
void Determinize(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DeterminizeOptions< Arc > &opts=DeterminizeOptions< Arc >())
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
typename Arc::Label Label
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
StateId Start() const override
typename Arc::StateId StateId
static const std::string & Type()
bool TopSort(MutableFst< Arc > *fst)
ArcTpl< TropicalWeight > StdArc
constexpr uint64_t kIdempotent
void Connect(MutableFst< Arc > *fst)
constexpr uint64_t kEpsilons
constexpr uint64_t kNotOLabelSorted
typename Arc::StateId StateId
const Arc & Value() const
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
constexpr uint64_t kRightSemiring
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
StateId AddState() override
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
AlgoTester(WeightGenerator generator, uint64_t seed)
constexpr uint8_t kPushWeights
constexpr uint64_t kNotILabelSorted
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
MapSymbolsAction OutputSymbolsAction() const
constexpr uint8_t kEncodeLabels
constexpr uint64_t kNoOEpsilons
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
constexpr uint64_t kNotAcceptor
constexpr uint64_t kAcyclic
constexpr uint64_t kCommutative
constexpr uint64_t kNoEpsilons
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa, uint64_t seed)
StateId NumStates() const override
MapFinalAction FinalAction() const
void LookAheadCompose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst)
A operator()(const A &arc) const
constexpr uint64_t kIDeterministic
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr uint64_t kIEpsilons
MapSymbolsAction InputSymbolsAction() const
void Concat(MutableFst< Arc > *fst1, const Fst< Arc > &fst2)
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
bool Verify(const Fst< Arc > &fst, bool allow_negative_labels=false)
constexpr uint8_t kEncodeWeights
constexpr uint64_t kILabelSorted
constexpr uint64_t kFstProperties
void Minimize(MutableFst< Arc > *fst, MutableFst< Arc > *sfst=nullptr, float delta=kShortestDelta, bool allow_nondet=false)
uint64_t Properties(uint64_t props) const
WeightedTester(uint64_t seed, const Fst< Arc > &zero_fst, const Fst< Arc > &one_fst, const Fst< Arc > &univ_fst, WeightGenerator weight_generator)
typename Arc::Label Label
typename Arc::Weight Weight
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
bool Equivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, float delta=kDelta, bool *error=nullptr)
void Test(const Fst< Arc > &T1, const Fst< Arc > &T2, const Fst< Arc > &T3)
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa, uint64_t seed)
constexpr uint64_t kLeftSemiring
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
typename Arc::Weight Weight
constexpr uint64_t kNoIEpsilons
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
uint64_t Properties(uint64_t mask, bool test) const override
constexpr uint64_t kAcceptor