18 #ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ 19 #define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ 49 template <
class Iterator>
51 template <
class Label>
52 std::string
JoinLabels(
const std::vector<Label> &labels,
59 typename A::Label
GuessStartOrEnd(std::vector<typename A::Label> *sequence,
60 typename A::Label boundary);
72 typedef typename A::Label
Label;
// Returns the current value of the `error_` flag; does not modify the object.
// NOTE(review): the enclosing class header is not visible in this excerpt.
90 bool Error()
const {
return error_; }
98 bool AddWord(Label word,
const std::vector<Label> &features);
108 bool AddWord(Label word,
const std::vector<Label> &word_features,
109 const std::vector<Label> &possible_output);
159 bool AddWeight(
size_t group,
const std::vector<Label> &input,
160 const std::vector<Label> &output, Weight weight);
175 std::map<Label, std::set<Label>> word_output_map_, word_feat_map_;
176 std::map<Label, std::set<size_t>> feat_groups_;
177 std::vector<std::unique_ptr<FeatureGroupBuilder<A>>> groups_;
178 size_t max_future_size_;
179 Label max_input_label_;
213 num_classes_(num_classes),
215 builder_(isyms, fsyms, osyms) {}
// Returns the current value of the `error_` flag; does not modify the object.
// NOTE(review): presumably this is the classifier builder's accessor, whose
// `error_` is raised when the wrapped `builder_` reports an error -- confirm
// against the full class definition.
219 bool Error()
const {
return error_; }
222 bool AddWord(Label word,
const std::vector<Label> &features);
235 bool AddWeight(
size_t group,
const std::vector<Label> &input, Label pred,
243 std::vector<Label> empty_;
245 size_t num_classes_, num_groups_;
262 : error_(false), future_size_(future_size), fsyms_(fsyms), osyms_(osyms) {
// Returns the current value of the `error_` flag; does not modify the object.
// NOTE(review): the constructor initializer just above starts `error_` as
// false; which operations set it is not fully visible in this excerpt.
271 bool Error()
const {
return error_; }
287 bool AddWeight(
const std::vector<Label> &input,
288 const std::vector<Label> &output, Weight weight);
317 int FindFirstMatch(InputOutputLabel label,
int parent,
int *hop)
const;
323 void BuildBackLinks();
329 void PreAccumulateWeights();
332 bool TrieDfs(
const Topology &topology,
int cur,
int target,
333 std::vector<InputOutputLabel> *path)
const;
334 std::string TriePath(
int node,
const Topology &topology)
const;
351 const std::vector<Label> &features) {
353 FSTERROR() <<
"Calling LinearFstDataBuilder<>::AddWord() at error state";
358 LOG(WARNING) <<
"Ignored: adding boundary label: " 367 FSTERROR() <<
"Word label must be > 0; got " << word;
370 if (word > max_input_label_) max_input_label_ = word;
372 if (word_feat_map_.find(word) != word_feat_map_.end()) {
375 <<
" is added twice";
379 std::set<Label> *feats = &word_feat_map_[word];
380 for (
size_t i = 0; i < features.size(); ++i) {
381 Label feat = features[i];
384 FSTERROR() <<
"Feature label must be > 0; got " << feat;
394 Label word,
const std::vector<Label> &word_features,
395 const std::vector<Label> &possible_output) {
397 FSTERROR() <<
"Calling LinearFstDataBuilder<>::AddWord() at error state";
400 if (!
AddWord(word, word_features))
return false;
402 if (possible_output.empty()) {
404 FSTERROR() <<
"Empty possible output constraint; " 405 <<
"use the two-parameter version if no constraint is need.";
408 std::set<Label> *outputs = &word_output_map_[word];
409 for (
size_t i = 0; i < possible_output.size(); ++i) {
410 Label output = possible_output[i];
414 <<
": adding boundary label as possible output: " << output
415 <<
"(start-of-sentence=" 423 FSTERROR() <<
"Output label must be > 0; got " << output;
426 outputs->insert(output);
427 all_output_labels_.Insert(output);
435 FSTERROR() <<
"Calling LinearFstDataBuilder<>::AddGroup() at error state";
438 size_t ret = groups_.size();
440 if (future_size > max_future_size_) max_future_size_ = future_size;
446 const std::vector<Label> &input,
447 const std::vector<Label> &output,
450 FSTERROR() <<
"Calling LinearFstDataBuilder<>::AddWeight() at error state";
455 bool start_in_middle =
false, end_in_middle =
false;
456 for (
int i = 1; i < input.size(); ++i) {
459 start_in_middle =
true;
462 end_in_middle =
true;
464 if (start_in_middle) {
465 LOG(WARNING) <<
"Ignored: start-of-sentence in the middle of the input!";
471 LOG(WARNING) <<
"Ignored: end-of-sentence in the middle of the input!";
479 bool non_first_start =
false, non_last_end =
false;
480 for (
int i = 1; i < output.size(); ++i) {
482 non_first_start =
true;
486 if (non_first_start) {
487 LOG(WARNING) <<
"Ignored: start-of-sentence not appearing " 488 <<
"as the first label in the output!";
494 LOG(WARNING) <<
"Ignored: end-of-sentence not appearing " 495 <<
"as the last label in the output!";
502 for (
size_t i = 0; i < input.size(); ++i) {
503 Label feat = input[i];
507 FSTERROR() <<
"Feature label must be > 0; got " << feat;
510 feat_groups_[feat].insert(group);
512 for (
size_t i = 0; i < output.size(); ++i) {
513 Label label = output[i];
517 FSTERROR() <<
"Output label must be > 0; got " << label;
522 all_output_labels_.Insert(label);
527 bool added = groups_[group]->AddWeight(input, output, weight);
528 if (groups_[group]->
Error()) {
530 FSTERROR() <<
"FeatureGroupBuilder<>::AddWeight() failed";
539 FSTERROR() <<
"Calling LinearFstDataBuilder<>::Dump() at error state";
544 data->max_future_size_ = max_future_size_;
545 data->max_input_label_ = max_input_label_;
548 data->groups_.resize(groups_.size());
549 for (
int group = 0; group != groups_.size(); ++group) {
551 if (new_group ==
nullptr) {
553 FSTERROR() <<
"Error in dumping group " << group;
556 data->groups_[group].reset(new_group);
557 groups_[group].reset();
558 VLOG(1) <<
"Group " << group <<
": " << new_group->
Stats();
562 data->group_feat_map_.Init(data->NumGroups(), max_input_label_ + 1);
563 for (
Label word = 1; word <= max_input_label_; ++word) {
564 typename std::map<Label, std::set<Label>>::const_iterator it =
565 word_feat_map_.find(word);
566 if (it == word_feat_map_.end())
continue;
567 for (
typename std::set<Label>::const_iterator oit = it->second.begin();
568 oit != it->second.end(); ++oit) {
570 typename std::map<Label, std::set<size_t>>::const_iterator jt =
571 feat_groups_.find(feat);
572 if (jt == feat_groups_.end())
continue;
573 for (std::set<size_t>::const_iterator git = jt->second.begin();
574 git != jt->second.end(); ++git) {
575 size_t group_id = *git;
576 if (!data->group_feat_map_.Set(group_id, word, feat)) {
586 std::vector<typename LinearFstData<A>::InputAttribute> *input_attribs =
587 &data->input_attribs_;
588 std::vector<Label> *output_pool = &data->output_pool_;
589 input_attribs->resize(max_input_label_ + 1);
590 for (
Label word = 0; word <= max_input_label_; ++word) {
591 typename std::map<Label, std::set<Label>>::const_iterator it =
592 word_output_map_.find(word);
593 if (it == word_output_map_.end()) {
594 (*input_attribs)[word].output_begin = 0;
595 (*input_attribs)[word].output_length = 0;
597 (*input_attribs)[word].output_begin = output_pool->size();
598 (*input_attribs)[word].output_length = it->second.size();
599 for (
typename std::set<Label>::const_iterator oit = it->second.begin();
600 oit != it->second.end(); ++oit) {
602 output_pool->push_back(olabel);
609 all_output_labels_.
Begin();
610 it != all_output_labels_.End(); ++it)
611 data->output_set_.push_back(*it);
614 return data.release();
622 Label word,
const std::vector<Label> &features) {
624 FSTERROR() <<
"Calling LinearClassifierFstDataBuilder<>::AddWord() at " 628 bool added = builder_.AddWord(word, features);
629 if (builder_.Error()) error_ =
true;
636 FSTERROR() <<
"Calling LinearClassifierFstDataBuilder<>::AddGroup() at " 640 for (
int i = 0; i < num_classes_; ++i) builder_.AddGroup(0);
641 if (builder_.Error()) {
645 return num_groups_++;
650 size_t group,
const std::vector<Label> &input,
Label pred,
Weight weight) {
652 FSTERROR() <<
"Calling LinearClassifierFstDataBuilder<>::AddWeight() at " 656 if (pred <= 0 || pred > num_classes_) {
657 FSTERROR() <<
"Out-of-range prediction label: " << pred
658 <<
" (num classes = " << num_classes_ <<
")";
662 size_t real_group = group * num_classes_ + pred - 1;
663 bool added = builder_.AddWeight(real_group, input, empty_, weight);
664 if (builder_.Error()) error_ =
true;
672 <<
"Calling LinearClassifierFstDataBuilder<>::Dump() at error state";
685 const std::vector<Label> &output,
688 FSTERROR() <<
"Calling FeatureGroupBuilder<>::AddWeight() at error state";
694 int num_input_start = 0;
695 while (num_input_start < input.size() &&
698 int num_output_start = 0;
699 while (num_output_start < output.size() &&
702 int num_input_end = 0;
703 for (
int i = input.size() - 1;
706 int num_output_end = 0;
707 for (
int i = output.size() - 1;
713 if (input.size() - num_input_start < future_size_) {
714 LOG(WARNING) <<
"Ignored: start-of-sentence in the future!";
719 if (num_input_start > 0 && input.size() - future_size_ - num_input_start <
720 output.size() - num_output_start) {
721 LOG(WARNING) <<
"Ignored: matching start-of-sentence with actual output!";
726 if (num_output_start > 0 && input.size() - future_size_ - num_input_start >
727 output.size() - num_output_start) {
728 LOG(WARNING) <<
"Ignored: matching start-of-sentence with actual input!";
734 if (num_input_end > future_size_ && num_input_end - future_size_ != 1) {
735 LOG(WARNING) <<
"Ignored: matching end-of-sentence with actual output!";
740 if (num_output_end > 0 &&
741 ((input.size() == future_size_ && future_size_ != num_input_end) ||
742 (input.size() > future_size_ &&
743 num_input_end != future_size_ + num_output_end))) {
744 LOG(WARNING) <<
"Ignored: matching end-of-sentence with actual input!";
751 if (num_input_start + num_input_end == input.size() &&
752 num_output_start + num_output_end == output.size()) {
754 <<
"Ignored: feature context consisting of only boundary marks!";
762 int cur = (num_input_start == 0 && num_output_start <= future_size_)
766 size_t ipos = num_input_start;
769 num_output_start <= future_size_ ? 0 : num_output_start - future_size_;
771 size_t iend = !input.empty() ? input.size() - num_output_end : 0,
772 oend = output.size() - num_output_end;
775 if (output.empty() && num_input_end > future_size_)
776 iend = input.size() - num_input_end + future_size_;
782 const size_t effective_input_size = iend - ipos,
783 effective_output_size = oend - opos;
784 if (effective_input_size > effective_output_size) {
785 for (
size_t pad = effective_input_size - effective_output_size; pad != 0;
788 }
else if (effective_input_size < effective_output_size) {
789 for (
size_t pad = effective_output_size - effective_input_size; pad != 0;
794 for (; ipos != iend; ++ipos, ++opos)
802 if (num_output_end > 0 || (output.empty() && num_input_end > future_size_))
803 trie_[cur].final_weight =
Times(trie_[cur].final_weight, weight);
805 trie_[cur].weight =
Times(trie_[cur].weight, weight);
813 FSTERROR() <<
"Calling FeatureGroupBuilder<>::PreAccumulateWeights() " 818 if (max_future_size < future_size_) {
820 FSTERROR() <<
"max_future_size (= " << max_future_size
821 <<
") is smaller the builder's future_size (= " << future_size_
827 if (error_)
return nullptr;
828 PreAccumulateWeights();
834 ret->next_state_.resize(trie_.NumNodes());
835 const Topology &topology = trie_.TrieTopology();
836 for (
int i = 0; i < topology.
NumNodes(); ++i) {
838 while (next != topology.
Root() && topology.
ChildrenOf(next).empty() &&
839 trie_[next].final_weight ==
840 trie_[trie_[next].back_link].final_weight)
841 next = trie_[next].back_link;
842 ret->next_state_[i] = next;
847 ret->trie_.swap(store_trie);
859 for (; parent >= 0; parent = trie_[parent].back_link, ++hop_count) {
860 int next = trie_.Find(parent, label);
866 if (hop !=
nullptr) *hop = hop_count;
882 const Topology &topology = trie_.TrieTopology();
887 q.push(trie_.Root());
888 while (!error_ && !q.empty()) {
889 int parent = q.front();
894 eit != children.end(); ++eit) {
895 const std::pair<InputOutputLabel, int> &edge = *eit;
897 int child = edge.second;
902 trie_[child].back_link =
903 FindFirstMatch(label, trie_[parent].back_link,
nullptr);
912 int only_input_hop, only_output_hop, full_hop;
913 int only_input_link =
920 FindFirstMatch(label, trie_[parent].back_link, &full_hop);
921 if (only_input_link != -1 && only_output_link != -1) {
923 FSTERROR() <<
"Branching back-off chain:\n" 924 <<
"\tnode " << child <<
": " << TriePath(child, topology)
926 <<
"\tcan back-off to node " << only_input_link <<
": " 927 << TriePath(only_input_link, topology) <<
"\n" 928 <<
"\tcan back-off to node " << only_output_link <<
": " 929 << TriePath(only_output_link, topology);
931 }
else if (full_link != -1) {
933 if (full_hop <= only_input_hop && full_hop <= only_output_hop) {
934 trie_[child].back_link = full_link;
941 FSTERROR() <<
"Branching back-off chain:\n" 942 <<
"\tnode " << child <<
": " 943 << TriePath(child, topology) <<
"\n" 944 <<
"\tcan back-off to node " << full_link <<
": " 945 << TriePath(full_link, topology) <<
"\n" 946 <<
"tcan back-off to node " << problem_link <<
": " 947 << TriePath(problem_link, topology);
951 trie_[child].back_link =
952 only_input_link != -1 ? only_input_link : only_output_link;
957 if (trie_[child].back_link == -1) trie_[child].back_link = 0;
965 std::vector<bool> visited(trie_.NumNodes(),
false);
966 visited[trie_.Root()] =
true;
968 for (
size_t i = 0; i != trie_.NumNodes(); ++i) {
969 std::stack<int> back_offs;
970 for (
int j = i; !visited[j]; j = trie_[j].back_link) back_offs.push(j);
971 while (!back_offs.empty()) {
972 int j = back_offs.top();
974 WeightBackLink &node = trie_[j];
975 node.weight =
Times(node.weight, trie_[node.back_link].weight);
977 Times(node.final_weight, trie_[node.back_link].final_weight);
985 const Topology &topology,
int cur,
int target,
986 std::vector<InputOutputLabel> *path)
const {
987 if (cur == target)
return true;
990 eit != children.end(); ++eit) {
991 const std::pair<InputOutputLabel, int> &edge = *eit;
992 path->push_back(edge.first);
993 if (TrieDfs(topology, edge.second, target, path))
return true;
1002 std::vector<InputOutputLabel> labels;
1003 TrieDfs(topology, topology.
Root(), node, &labels);
1005 std::ostringstream strm;
1006 for (
typename std::vector<InputOutputLabel>::const_iterator it =
1008 it != labels.end(); ++it) {
1022 if (syms !=
nullptr) ret += syms->
Find(label);
1024 std::ostringstream strm;
1025 strm << '<' << label << '>
'; 1031 template <class Iterator> 1032 std::string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms) { 1033 if (begin == end) return "<empty>"; 1034 std::ostringstream strm; 1036 for (Iterator it = begin; it != end; ++it) { 1041 strm << TranslateLabel(*it, syms); 1046 template <class Label> 1047 std::string JoinLabels(const std::vector<Label> &labels, 1048 const SymbolTable *syms) { 1049 return JoinLabels(labels.begin(), labels.end(), syms); 1053 typename A::Label GuessStartOrEnd(std::vector<typename A::Label> *sequence, 1054 typename A::Label boundary) { 1055 const size_t length = sequence->size(); 1056 std::vector<bool> non_boundary_on_left(length, false), 1057 non_boundary_on_right(length, false); 1058 for (size_t i = 1; i < length; ++i) { 1059 non_boundary_on_left[i] = 1060 non_boundary_on_left[i - 1] || (*sequence)[i - 1] != boundary; 1061 non_boundary_on_right[length - 1 - i] = non_boundary_on_right[length - i] || 1062 (*sequence)[length - i] != boundary; 1065 for (size_t i = 0; i < length; ++i) { 1066 if ((*sequence)[i] != boundary) continue; 1067 const bool left = non_boundary_on_left[i], right = non_boundary_on_right[i]; 1068 if (left && right) { 1069 // Boundary in the middle 1070 LOG(WARNING) << "Boundary label in the middle of the sequence! position: " 1071 << i << "; boundary: " << boundary 1072 << "; sequence: " << JoinLabels(*sequence, nullptr); 1074 << "This is an invalid sequence anyway so I will set it to start."; 1075 (*sequence)[i] = LinearFstData<A>::kStartOfSentence; 1076 } else if (left && !right) { 1078 (*sequence)[i] = LinearFstData<A>::kEndOfSentence; 1079 } else if (!left && right) { 1080 // Can only be start 1081 (*sequence)[i] = LinearFstData<A>::kStartOfSentence; 1083 // !left && !right; can't really tell
1092 #endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ std::string Stats() const
FeatureGroup< A > * Dump(size_t max_future_size)
FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms, const SymbolTable *osyms)
ErrorWeight Times(const ErrorWeight &, const ErrorWeight &)
bool AddWeight(size_t group, const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
bool AddWord(Label word, const std::vector< Label > &features)
LinearFstData< A > * Dump()
bool AddWeight(const std::vector< Label > &input, const std::vector< Label > &output, Weight weight)
int AddGroup(size_t future_size)
LinearClassifierFstDataBuilder(size_t num_classes, const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
bool AddWord(Label word, const std::vector< Label > &features)
std::string TranslateLabel(int64_t label, const SymbolTable *syms)
LinearFstDataBuilder(const SymbolTable *isyms=nullptr, const SymbolTable *fsyms=nullptr, const SymbolTable *osyms=nullptr)
constexpr int kNoTrieNodeId
std::string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms)
bool AddWeight(size_t group, const std::vector< Label > &input, Label pred, Weight weight)
const NextMap & ChildrenOf(int parent) const
std::unordered_map< L, int, H > NextMap
const_iterator Begin() const
std::string Find(int64_t key) const
A::Label GuessStartOrEnd(std::vector< typename A::Label > *sequence, typename A::Label boundary)
LinearFstData< A > * Dump()