1
#include "guess/guesser.h"

#include <algorithm>
#include <cctype>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>

#include "guess/string_util.h"
10
+
11
+ namespace guess {
12
+
13
+ guesser::guesser (std::vector<std::string> const & candidates)
14
+ : candidates_(candidates) {
15
+ for (auto & candidate : candidates_) {
16
+ normalize (candidate);
17
+ }
18
+ }
19
+
20
+ std::vector<int > guesser::guess (std::string in, int count) const {
21
+ auto matches = match_trigrams (in);
22
+ matches.resize (std::min (100ul , matches.size ()));
23
+
24
+ score_exact_word_matches (in, matches);
25
+ matches.resize (count);
26
+
27
+ std::vector<int > ret (matches.size ());
28
+ for (int i = 0 ; i < matches.size (); ++i) {
29
+ ret[i] = matches[i].index ;
30
+ }
31
+
32
+ return ret;
33
+ }
34
+
35
+ std::vector<guesser::match> guesser::match_trigrams (std::string& in) const {
36
+ std::vector<match> matches;
37
+ matches.reserve (candidates_.size ());
38
+ for (int i = 0 ; i < candidates_.size (); ++i) {
39
+ matches.emplace_back (i);
40
+ }
41
+
42
+ normalize (in);
43
+
44
+ char const * input = in.c_str ();
45
+ double sqrt_len_vec_input = std::sqrt (in.size () - 2 );
46
+
47
+ char trigram_input[4 ] = {0 };
48
+ char trigram_candidate[4 ] = {0 };
49
+
50
+ for (int i = 0 ; i < candidates_.size (); ++i) {
51
+ int match_count = 0 ;
52
+ const int len_vec_candidate = candidates_[i].length () - 2 ;
53
+
54
+ char const * substr_input = input;
55
+ while (substr_input[2 ] != ' \0 ' ) {
56
+ trigram_input[0 ] = substr_input[0 ];
57
+ trigram_input[1 ] = substr_input[1 ];
58
+ trigram_input[2 ] = substr_input[2 ];
59
+ ++substr_input;
60
+
61
+ char const * substr_candidate = candidates_[i].c_str ();
62
+ while (substr_candidate[2 ] != ' \0 ' ) {
63
+ trigram_candidate[0 ] = substr_candidate[0 ];
64
+ trigram_candidate[1 ] = substr_candidate[1 ];
65
+ trigram_candidate[2 ] = substr_candidate[2 ];
66
+ ++substr_candidate;
67
+
68
+ if (*(uint32_t *) trigram_input == *(uint32_t *) trigram_candidate) {
69
+ ++match_count;
70
+ break ;
71
+ }
72
+ }
73
+ }
74
+
75
+ double denominator = sqrt_len_vec_input * std::sqrt (len_vec_candidate);
76
+ matches[i].cos_sim = match_count / denominator;
77
+ }
78
+ std::sort (std::begin (matches), std::end (matches));
79
+
80
+ return matches;
81
+ }
82
+
83
+ void guesser::score_exact_word_matches (std::string& in,
84
+ std::vector<match>& matches) const {
85
+ for (int i = 0 ; i < matches.size (); ++i) {
86
+ auto & candidate = candidates_[matches[i].index ];
87
+ for_each_token (in, [&](char * input_token) {
88
+ for_each_token (candidate, [&](char * candidate_token) {
89
+ if (strcmp (candidate_token, input_token) == 0 ) {
90
+ matches[i].cos_sim *= 1.33 ;
91
+ return true ;
92
+ }
93
+ return false ;
94
+ });
95
+ return false ;
96
+ });
97
+ }
98
+ std::sort (std::begin (matches), std::end (matches));
99
+ }
100
+
101
+ void guesser::normalize (std::string& s) {
102
+ replace_all (s, " Ä" , " a" );
103
+ replace_all (s, " ä" , " a" );
104
+ replace_all (s, " Ö" , " o" );
105
+ replace_all (s, " ö" , " o" );
106
+ replace_all (s, " Ü" , " u" );
107
+ replace_all (s, " ü" , " u" );
108
+ replace_all (s, " ß" , " ss" );
109
+ replace_all (s, " -" , " " );
110
+ replace_all (s, " /" , " " );
111
+ replace_all (s, " ." , " " );
112
+ replace_all (s, " ," , " " );
113
+ replace_all (s, " (" , " " );
114
+ replace_all (s, " )" , " " );
115
+
116
+ for (int i = 0 ; i < s.length (); ++i) {
117
+ if (!isalnum (s[i])) {
118
+ s[i] = ' ' ;
119
+ }
120
+ }
121
+
122
+ replace_all (s, " " , " " );
123
+
124
+ std::transform (s.begin (), s.end (), s.begin (), ::tolower);
125
+ }
126
+
127
+ } // namespace guess
0 commit comments