20 #ifndef FST_TEST_ALGO_TEST_H_ 21 #define FST_TEST_ALGO_TEST_H_ 23 #include <sys/types.h> 96 return A(0, 0, arc.weight, arc.nextstate);
126 std::vector<StdArc::StateId> order;
158 generate_(std::move(weight_generator)) {}
161 TestRational(T1, T2, T3);
163 TestCompose(T1, T2, T3);
174 VLOG(1) <<
"Check destructive and delayed union are equivalent.";
178 CHECK(Equiv(U1, U2));
182 VLOG(1) <<
"Check destructive and delayed concatenation are equivalent.";
186 CHECK(Equiv(C1, C2));
189 CHECK(Equiv(C3, C2));
193 VLOG(1) <<
"Check destructive and delayed closure* are equivalent.";
197 CHECK(Equiv(C1, C2));
201 VLOG(1) <<
"Check destructive and delayed closure+ are equivalent.";
205 CHECK(Equiv(C1, C2));
209 VLOG(1) <<
"Check union is associative (destructive).";
219 CHECK(Equiv(U1, U4));
223 VLOG(1) <<
"Check union is associative (delayed).";
230 CHECK(Equiv(U2, U4));
234 VLOG(1) <<
"Check union is associative (destructive delayed).";
241 CHECK(Equiv(U1, U4));
245 VLOG(1) <<
"Check concatenation is associative (destructive).";
255 CHECK(Equiv(C1, C4));
259 VLOG(1) <<
"Check concatenation is associative (delayed).";
266 CHECK(Equiv(C2, C4));
270 VLOG(1) <<
"Check concatenation is associative (destructive delayed).";
277 CHECK(Equiv(C1, C4));
281 VLOG(1) <<
"Check concatenation left distributes" 282 <<
" over union (destructive).";
296 CHECK(Equiv(C1, U2));
300 VLOG(1) <<
"Check concatenation right distributes" 301 <<
" over union (destructive).";
314 CHECK(Equiv(C1, U2));
318 VLOG(1) <<
"Check concatenation left distributes over union (delayed).";
326 CHECK(Equiv(C1, U2));
330 VLOG(1) <<
"Check concatenation right distributes over union (delayed).";
338 CHECK(Equiv(C1, U2));
342 VLOG(1) <<
"Check T T* == T+ (destructive).";
355 VLOG(1) <<
"Check T* T == T+ (destructive).";
368 VLOG(1) <<
"Check T T* == T+ (delayed).";
378 VLOG(1) <<
"Check T* T == T+ (delayed).";
391 VLOG(1) <<
"Check destructive and delayed projection are equivalent.";
395 CHECK(Equiv(P1, P2));
399 VLOG(1) <<
"Check destructive and delayed inversion are equivalent.";
403 CHECK(Equiv(I1, I2));
407 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (destructive).";
413 CHECK(Equiv(P1, I1));
417 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (destructive).";
423 CHECK(Equiv(P1, I1));
427 VLOG(1) <<
"Check Pi_1(T) = Pi_2(T^-1) (delayed).";
431 CHECK(Equiv(P1, P2));
435 VLOG(1) <<
"Check Pi_2(T) = Pi_1(T^-1) (delayed).";
439 CHECK(Equiv(P1, P2));
443 VLOG(1) <<
"Check destructive relabeling";
444 static const int kNumLabels = 10;
446 std::vector<Label> labelset(kNumLabels);
447 for (
size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
448 for (
size_t i = 0; i < kNumLabels; ++i) {
451 std::uniform_int_distribution<>(0, kNumLabels - 1)(rand_);
452 swap(labelset[i], labelset[index]);
455 std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
456 std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
457 for (
size_t i = 0; i < kNumLabels; ++i) {
458 ipairs1[i] = std::make_pair(i, labelset[i]);
459 opairs1[i] = std::make_pair(labelset[i], i);
464 std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
465 std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
466 for (
size_t i = 0; i < kNumLabels; ++i) {
467 ipairs2[i] = std::make_pair(labelset[i], i);
468 opairs2[i] = std::make_pair(i, labelset[i]);
473 VLOG(1) <<
"Check on-the-fly relabeling";
477 CHECK(Equiv(RRdelay, T));
481 VLOG(1) <<
"Check encoding/decoding (destructive).";
483 uint8_t encode_props = 0;
484 if (std::bernoulli_distribution(.5)(rand_)) {
487 if (std::bernoulli_distribution(.5)(rand_)) {
497 VLOG(1) <<
"Check encoding/decoding (delayed).";
498 uint8_t encode_props = 0;
499 if (std::bernoulli_distribution(.5)(rand_)) {
502 if (std::bernoulli_distribution(.5)(rand_)) {
513 VLOG(1) <<
"Check gallic mappers (constructive).";
519 ArcMap(G, &F, from_mapper);
524 VLOG(1) <<
"Check gallic mappers (delayed).";
547 VLOG(1) <<
"Check composition is associative.";
553 CHECK(Equiv(C2, C4));
557 VLOG(1) <<
"Check composition left distributes over union.";
565 CHECK(Equiv(C1, U2));
569 VLOG(1) <<
"Check composition right distributes over union.";
577 CHECK(Equiv(C1, U2));
588 VLOG(1) <<
"Check intersection is commutative.";
591 CHECK(Equiv(I1, I2));
595 VLOG(1) <<
"Check all epsilon filters leads to equivalent results.";
603 CHECK(Equiv(C1, C2));
604 CHECK(Equiv(C1, C3));
611 CHECK(Equiv(C1, C4));
614 CHECK(Equiv(C1, C5));
621 CHECK(Equiv(C1, C6));
626 VLOG(1) <<
"Check look-ahead filters lead to equivalent results.";
630 CHECK(Equiv(C1, C2));
640 VLOG(1) <<
"Check arc sorted Fst is equivalent to its input.";
647 VLOG(1) <<
"Check destructive and delayed arcsort are equivalent.";
651 CHECK(Equiv(S1, S2));
655 VLOG(1) <<
"Check ilabel sorting vs. olabel sorting with inversions.";
662 CHECK(Equiv(S1, S2));
666 VLOG(1) <<
"Check topologically sorted Fst is equivalent to its input.";
673 VLOG(1) <<
"Check reverse(reverse(T)) = T";
674 for (
int i = 0; i < 2; ++i) {
677 bool require_superinitial = i == 1;
678 Reverse(T, &R1, require_superinitial);
679 Reverse(R1, &R2, require_superinitial);
686 void TestOptimize(
const Fst<Arc> &T) {
688 uint64_t wprops = Weight::Properties();
693 VLOG(1) <<
"Check connected FST is equivalent to its input.";
701 VLOG(1) <<
"Check epsilon-removed FST is equivalent to its input.";
706 VLOG(1) <<
"Check destructive and delayed epsilon removal" 707 <<
"are equivalent.";
709 CHECK(Equiv(R1, R2));
711 VLOG(1) <<
"Check an FST with a large proportion" 712 <<
" of epsilon transitions:";
719 Arc arc(1, 1, Weight::One(), V.
AddState());
721 V.
SetFinal(arc.nextstate, Weight::One());
725 std::vector<Weight> d;
739 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
740 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
745 VLOG(1) <<
"Check determinized FST is equivalent to its input.";
753 VLOG(1) <<
"Check pruning in determinization";
755 const Weight threshold = generate_();
760 CHECK(PruneEquiv(A, P, threshold));
763 if ((wprops &
kPath) == kPath) {
764 VLOG(1) <<
"Check min-determinization";
768 std::vector<std::pair<Label, Label>> ipairs, opairs;
769 ipairs.push_back(std::pair<Label, Label>(0, 1));
777 CHECK(MinRelated(M, R));
782 VLOG(1) <<
"Check size(min(det(A))) <= size(det(A))" 783 <<
" and min(det(A)) equiv det(A)";
792 if (n && (wprops & kIdempotent) == kIdempotent &&
794 VLOG(1) <<
"Check that Revuz's algorithm leads to the" 795 <<
" same number of states as Brozozowski's algorithm";
813 if ((wprops & kSemiring) == kSemiring && tprops &
kAcyclic) {
814 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
819 VLOG(1) <<
"Check disambiguated FSA is unambiguous";
820 CHECK(Unambiguous(D));
837 VLOG(1) <<
"Check reweight(T) equiv T";
838 std::vector<Weight> potential;
841 while (potential.size() < RI.
NumStates()) {
842 potential.push_back(generate_());
852 if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
853 VLOG(1) <<
"Check pushed FST is equivalent to input FST.";
872 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1,
kPushLabels);
886 VLOG(1) <<
"Check pruning algorithm";
888 VLOG(1) <<
"Check equiv. of constructive and destructive algorithms";
889 const Weight threshold = generate_();
891 Prune(&P1, threshold);
893 Prune(T, &P2, threshold);
894 CHECK(Equiv(P1, P2));
898 VLOG(1) <<
"Check prune(reverse) equiv reverse(prune)";
899 const Weight threshold = generate_();
903 Prune(&P1, threshold);
905 Prune(&R, threshold.Reverse());
907 CHECK(Equiv(P1, P2));
910 VLOG(1) <<
"Check: ShortestDistance(A - prune(A))" 911 <<
" > ShortestDistance(A) times Threshold";
912 const Weight threshold = generate_();
914 Prune(A, &P, threshold);
915 CHECK(PruneEquiv(A, P, threshold));
919 if (tprops & kAcyclic) {
920 VLOG(1) <<
"Check synchronize(T) equiv T";
927 void TestSearch(
const Fst<Arc> &T) {
929 uint64_t wprops = Weight::Properties();
935 VLOG(1) <<
"Check 1-best weight.";
944 VLOG(1) <<
"Check n-best weights";
948 const int nshortest = std::uniform_int_distribution<>(
949 0, kNumRandomShortestPaths + 1)(rand_);
952 false, Weight::Zero(), kNumShortestStates,
954 std::vector<Weight> distance;
959 for (; !piter.
Done(); piter.
Next()) {
961 Weight nsum = s < distance.size()
983 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
991 return RandEquivalent(fst1, fst2, kNumRandomPaths, opts, kTestDelta, seed_);
1002 return Equiv(lfst1, lfst2);
1009 bool MinRelated(
const Fst<A> &fst1,
const Fst<A> &fst2) {
1014 if (!Equiv(P1, P2)) {
1015 LOG(ERROR) <<
"Inputs not equivalent";
1025 for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
1031 Compose(paths1, path, &paths2);
1034 LOG(ERROR) <<
"Sums not equivalent: " << sum1 <<
" " << sum2;
1044 VLOG(1) <<
"Check FSTs for sanity (including property bits).";
1058 std::mt19937_64 rand_;
1068 static constexpr
int kRandomPathLength = 25;
1070 static constexpr
int kNumRandomPaths = 100;
1072 static constexpr
int kNumRandomShortestPaths = 100;
1074 static constexpr
int kNumShortestStates = 10000;
1076 static constexpr
float kTestDelta = .05;
1085 template <
class Arc>
1089 const Fst<Arc> &univ_fsa, uint64_t seed) {}
1106 const Fst<Arc> &univ_fsa, uint64_t seed)
1107 : zero_fsa_(zero_fsa),
1109 univ_fsa_(univ_fsa),
1113 TestRational(A1, A2, A3);
1114 TestIntersect(A1, A2, A3);
1123 VLOG(1) <<
"Check the union contains its arguments (destructive).";
1127 CHECK(Subset(A1, U));
1128 CHECK(Subset(A2, U));
1132 VLOG(1) <<
"Check the union contains its arguments (delayed).";
1135 CHECK(Subset(A1, U));
1136 CHECK(Subset(A2, U));
1140 VLOG(1) <<
"Check if A^n c A* (destructive).";
1142 const int n = std::uniform_int_distribution<>(0, 4)(rand_);
1143 for (
int i = 0; i < n; ++i)
Concat(&C, A1);
1147 CHECK(Subset(C, S));
1151 VLOG(1) <<
"Check if A^n c A* (delayed).";
1152 const int n = std::uniform_int_distribution<>(0, 4)(rand_);
1153 std::unique_ptr<Fst<Arc>> C = std::make_unique<VectorFst<Arc>>(one_fsa_);
1154 for (
int i = 0; i < n; ++i) {
1155 C = std::make_unique<ConcatFst<Arc>>(*C, A1);
1158 CHECK(Subset(*C, S));
1176 VLOG(1) <<
"Check the intersection is contained in its arguments.";
1178 CHECK(Subset(I1, S1));
1179 CHECK(Subset(I1, S2));
1183 VLOG(1) <<
"Check union distributes over intersection.";
1192 CHECK(Equiv(U1, I2));
1197 Complement(S1, &C1);
1198 Complement(S2, &C2);
1203 VLOG(1) <<
"Check S U S' = Sigma*";
1205 CHECK(Equiv(U, univ_fsa_));
1209 VLOG(1) <<
"Check S n S' = {}";
1211 CHECK(Equiv(I, zero_fsa_));
1215 VLOG(1) <<
"Check (S1' U S2') == (S1 n S2)'";
1221 CHECK(Equiv(U, C3));
1225 VLOG(1) <<
"Check (S1' n S2') == (S1 U S2)'";
1231 CHECK(Equiv(I, C3));
1236 void TestOptimize(
const Fst<Arc> &A) {
1238 VLOG(1) <<
"Check determinized FSA is equivalent to its input.";
1244 VLOG(1) <<
"Check disambiguated FSA is equivalent to its input.";
1253 VLOG(1) <<
"Check minimized FSA is equivalent to its input.";
1265 VLOG(1) <<
"Check that Hopcroft's and Revuz's algorithms lead to the" 1266 <<
" same number of states as Brozozowski's algorithm";
1283 VLOG(1) <<
"Check FSAs for sanity (including property bits).";
1308 Union(&ufsa, dfsa2);
1313 CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1320 VLOG(1) <<
"Check FSAs (incl. property bits) for sanity";
1352 std::mt19937_64 rand_;
1359 template <
class Arc>
1368 : generate_(std::move(generator)), rand_(seed) {
1369 one_fst_.AddState();
1370 one_fst_.SetStart(0);
1371 one_fst_.SetFinal(0);
1373 univ_fst_.AddState();
1374 univ_fst_.SetStart(0);
1375 univ_fst_.SetFinal(0);
1376 for (
int i = 0; i < kNumRandomLabels; ++i) univ_fst_.EmplaceArc(0, i, i, 0);
1379 univ_fst_, generate_));
1381 unweighted_tester_.reset(
1386 RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
1387 kNumRandomLabels, kAcyclicProb, generate_,
1392 VLOG(1) <<
"weight type = " << Weight::Type();
1394 for (
int i = 0; i < FST_FLAGS_repeat; ++i) {
1402 weighted_tester_->Test(T1, T2, T3);
1410 ArcMap(&A1, rm_weight_mapper_);
1411 ArcMap(&A2, rm_weight_mapper_);
1412 ArcMap(&A3, rm_weight_mapper_);
1413 unweighted_tester_->Test(A1, A2, A3);
1421 std::mt19937_64 rand_;
1429 std::unique_ptr<WeightedTester<Arc>> weighted_tester_;
1431 std::unique_ptr<UnweightedTester<Arc>> unweighted_tester_;
1435 static constexpr
int kNumRandomStates = 10;
1437 static constexpr
int kNumRandomArcs = 25;
1439 static constexpr
int kNumRandomLabels = 5;
1441 static constexpr
float kAcyclicProb = .25;
1443 static constexpr
int kRandomPathLength = 25;
1445 static constexpr
int kNumRandomPaths = 100;
1452 #endif // FST_TEST_ALGO_TEST_H_
constexpr uint64_t kSemiring
void ArcMap(MutableFst< A > *fst, C *mapper)
bool RandEquivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, int32_t npath, const RandGenOptions< ArcSelector > &opts, float delta=kDelta, uint64_t seed=std::random_device()(), bool *error=nullptr)
void RmEpsilon(MutableFst< Arc > *fst, std::vector< typename Arc::Weight > *distance, const RmEpsilonOptions< Arc, Queue > &opts)
void Invert(const Fst< Arc > &ifst, MutableFst< Arc > *ofst)
virtual uint64_t Properties(uint64_t mask, bool test) const =0
constexpr uint64_t kOEpsilons
void Encode(MutableFst< Arc > *fst, EncodeMapper< Arc > *mapper)
constexpr uint8_t kPushLabels
void MakeRandFst(MutableFst< Arc > *fst)
ErrorWeight Plus(const ErrorWeight &, const ErrorWeight &)
void Closure(MutableFst< Arc > *fst, ClosureType closure_type)
static void Relabel(MutableFst< Arc > *fst, const LFST &mfst, bool relabel_input)
void Determinize(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DeterminizeOptions< Arc > &opts=DeterminizeOptions< Arc >())
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
typename Arc::Label Label
void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, bool access_only=false)
StateId Start() const override
static const std::string & Type()
typename Arc::StateId StateId
bool TopSort(MutableFst< Arc > *fst)
ArcTpl< TropicalWeight > StdArc
constexpr uint64_t kIdempotent
void Connect(MutableFst< Arc > *fst)
constexpr uint64_t kEpsilons
constexpr uint64_t kNotOLabelSorted
typename Arc::StateId StateId
const Arc & Value() const
uint64_t Properties(uint64_t mask, bool test) const override
void Relabel(MutableFst< Arc > *fst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &ipairs, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &opairs)
void Reverse(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, std::vector< typename Arc::Label > *assignments, MutableFst< RevArc > *ofst)
void Prune(MutableFst< Arc > *fst, const PruneOptions< Arc, ArcFilter > &opts=PruneOptions< Arc, ArcFilter >())
constexpr uint64_t kRightSemiring
void Difference(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const DifferenceOptions &opts=DifferenceOptions())
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
AlgoTester(WeightGenerator generator, uint64_t seed)
constexpr uint8_t kPushWeights
constexpr uint64_t kNotILabelSorted
void RandGen(const Fst< FromArc > &ifst, MutableFst< ToArc > *ofst, const RandGenOptions< Selector > &opts)
MapSymbolsAction OutputSymbolsAction() const
constexpr uint8_t kEncodeLabels
constexpr uint64_t kNoOEpsilons
void Union(RationalFst< Arc > *fst1, const Fst< Arc > &fst2)
void ArcSort(MutableFst< Arc > *fst, Compare comp)
constexpr uint64_t kOLabelSorted
constexpr uint64_t kNotAcceptor
constexpr uint64_t kAcyclic
constexpr uint64_t kCommutative
constexpr uint64_t kNoEpsilons
void Compose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst, const ComposeOptions &opts=ComposeOptions())
StateId NumStates() const override
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa, uint64_t seed)
MapFinalAction FinalAction() const
void LookAheadCompose(const Fst< Arc > &ifst1, const Fst< Arc > &ifst2, MutableFst< Arc > *ofst)
A operator()(const A &arc) const
constexpr uint64_t kIDeterministic
void ShortestPath(const Fst< Arc > &ifst, const std::vector< std::pair< typename Arc::Label, typename Arc::Label >> &parens, MutableFst< Arc > *ofst, const PdtShortestPathOptions< Arc, Queue > &opts)
constexpr uint64_t kIEpsilons
StateId AddState() override
MapSymbolsAction InputSymbolsAction() const
void Concat(MutableFst< Arc > *fst1, const Fst< Arc > &fst2)
void AddArc(StateId s, const Arc &arc) override
void Project(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, ProjectType project_type)
bool Verify(const Fst< Arc > &fst, bool allow_negative_labels=false)
constexpr uint8_t kEncodeWeights
constexpr uint64_t kILabelSorted
constexpr uint64_t kFstProperties
void Minimize(MutableFst< Arc > *fst, MutableFst< Arc > *sfst=nullptr, float delta=kShortestDelta, bool allow_nondet=false)
uint64_t Properties(uint64_t props) const
void SetFinal(StateId s, Weight weight=Weight::One()) override
WeightedTester(uint64_t seed, const Fst< Arc > &zero_fst, const Fst< Arc > &one_fst, const Fst< Arc > &univ_fst, WeightGenerator weight_generator)
typename Arc::Label Label
typename Arc::Weight Weight
void Test(const Fst< Arc > &A1, const Fst< Arc > &A2, const Fst< Arc > &A3)
bool Equivalent(const Fst< Arc > &fst1, const Fst< Arc > &fst2, float delta=kDelta, bool *error=nullptr)
void Test(const Fst< Arc > &T1, const Fst< Arc > &T2, const Fst< Arc > &T3)
void ShortestDistance(const Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)
ArcMapFst(const Fst< typename ArcMapper::FromArc > &, const ArcMapper &) -> ArcMapFst< typename ArcMapper::FromArc, typename ArcMapper::ToArc, ArcMapper, DefaultCacheStore< typename ArcMapper::ToArc >, PropagateKExpanded::kNo >
UnweightedTester(const Fst< Arc > &zero_fsa, const Fst< Arc > &one_fsa, const Fst< Arc > &univ_fsa, uint64_t seed)
constexpr uint64_t kLeftSemiring
void Disambiguate(const Fst< Arc > &ifst, MutableFst< Arc > *ofst, const DisambiguateOptions< Arc > &opts=DisambiguateOptions< Arc >())
typename Arc::Weight Weight
constexpr uint64_t kNoIEpsilons
void Decode(MutableFst< Arc > *fst, const EncodeMapper< Arc > &mapper)
std::bool_constant<(W::Properties()&kPath)!=0 > IsPath
void Reweight(MutableFst< Arc > *fst, const std::vector< typename Arc::Weight > &potential, ReweightType type)
bool ApproxEqual(const ErrorWeight &, const ErrorWeight &, float)
void SetStart(StateId s) override
constexpr uint64_t kAcceptor