23
23
import org .apache .lucene .analysis .tokenattributes .CharTermAttribute ;
24
24
import org .apache .lucene .analysis .tokenattributes .OffsetAttribute ;
25
25
import org .apache .lucene .analysis .tokenattributes .PositionIncrementAttribute ;
26
+ import org .apache .lucene .codecs .TermStats ;
26
27
import org .apache .lucene .index .IndexReader ;
27
28
import org .apache .lucene .index .MultiFields ;
28
29
import org .apache .lucene .index .Term ;
48
49
49
50
import static java .lang .Math .log10 ;
50
51
import static java .lang .Math .max ;
52
+ import static java .lang .Math .min ;
51
53
import static java .lang .Math .round ;
52
54
53
55
public final class DirectCandidateGenerator extends CandidateGenerator {
@@ -57,20 +59,20 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
57
59
private final SuggestMode suggestMode ;
58
60
private final TermsEnum termsEnum ;
59
61
private final IndexReader reader ;
60
- private final long dictSize ;
62
+ private final long sumTotalTermFreq ;
61
63
private static final double LOG_BASE = 5 ;
62
64
private final long frequencyPlateau ;
63
65
private final Analyzer preFilter ;
64
66
private final Analyzer postFilter ;
65
67
private final double nonErrorLikelihood ;
66
- private final boolean useTotalTermFrequency ;
67
68
private final CharsRefBuilder spare = new CharsRefBuilder ();
68
69
private final BytesRefBuilder byteSpare = new BytesRefBuilder ();
69
70
private final int numCandidates ;
70
71
71
72
/**
 * Convenience constructor that delegates to the full constructor.
 *
 * The terms are resolved with {@code MultiFields.getTerms(reader, field)} and the two
 * arguments preceding them are passed as {@code null} (presumably the pre- and
 * post-filter analyzers; confirm against the full constructor's signature).
 *
 * @throws IOException if the terms cannot be resolved or the delegate constructor fails
 */
public DirectCandidateGenerator(
        DirectSpellChecker spellchecker,
        String field,
        SuggestMode suggestMode,
        IndexReader reader,
        double nonErrorLikelihood,
        int numCandidates) throws IOException {
    this(
        spellchecker,
        field,
        suggestMode,
        reader,
        nonErrorLikelihood,
        numCandidates,
        null,
        null,
        MultiFields.getTerms(reader, field));
}
75
77
76
78
public DirectCandidateGenerator (DirectSpellChecker spellchecker , String field , SuggestMode suggestMode , IndexReader reader ,
@@ -83,14 +85,12 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
83
85
this .numCandidates = numCandidates ;
84
86
this .suggestMode = suggestMode ;
85
87
this .reader = reader ;
86
- final long dictSize = terms .getSumTotalTermFreq ();
87
- this .useTotalTermFrequency = dictSize != -1 ;
88
- this .dictSize = dictSize == -1 ? reader .maxDoc () : dictSize ;
88
+ this .sumTotalTermFreq = terms .getSumTotalTermFreq () == -1 ? reader .maxDoc () : terms .getSumTotalTermFreq ();
89
89
this .preFilter = preFilter ;
90
90
this .postFilter = postFilter ;
91
91
this .nonErrorLikelihood = nonErrorLikelihood ;
92
92
float thresholdFrequency = spellchecker .getThresholdFrequency ();
93
- this .frequencyPlateau = thresholdFrequency >= 1.0f ? (int ) thresholdFrequency : (int )( dictSize * thresholdFrequency );
93
+ this .frequencyPlateau = thresholdFrequency >= 1.0f ? (int ) thresholdFrequency : (int ) ( reader . maxDoc () * thresholdFrequency );
94
94
termsEnum = terms .iterator ();
95
95
}
96
96
@@ -99,24 +99,29 @@ public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, S
99
99
*/
100
100
@ Override
101
101
public boolean isKnownWord (BytesRef term ) throws IOException {
102
- return frequency (term ) > 0 ;
102
+ return termStats (term ). docFreq > 0 ;
103
103
}
104
104
105
105
/* (non-Javadoc)
106
106
* @see org.elasticsearch.search.suggest.phrase.CandidateGenerator#frequency(org.apache.lucene.util.BytesRef)
107
107
*/
108
108
@ Override
109
- public long frequency (BytesRef term ) throws IOException {
109
+ public TermStats termStats (BytesRef term ) throws IOException {
110
110
term = preFilter (term , spare , byteSpare );
111
- return internalFrequency (term );
111
+ return internalTermStats (term );
112
112
}
113
113
114
114
115
- public long internalFrequency (BytesRef term ) throws IOException {
115
+ public TermStats internalTermStats (BytesRef term ) throws IOException {
116
116
if (termsEnum .seekExact (term )) {
117
- return useTotalTermFrequency ? termsEnum .totalTermFreq () : termsEnum .docFreq ();
117
+ return new TermStats (termsEnum .docFreq (),
118
+ /**
119
+ * We use the {@link TermsEnum#docFreq()} for fields that don't
120
+ * record the {@link TermsEnum#totalTermFreq()}.
121
+ */
122
+ termsEnum .totalTermFreq () == -1 ? termsEnum .docFreq () : termsEnum .totalTermFreq ());
118
123
}
119
- return 0 ;
124
+ return new TermStats ( 0 , 0 ) ;
120
125
}
121
126
122
127
public String getField () {
@@ -127,15 +132,28 @@ public String getField() {
127
132
public CandidateSet drawCandidates (CandidateSet set ) throws IOException {
128
133
Candidate original = set .originalTerm ;
129
134
BytesRef term = preFilter (original .term , spare , byteSpare );
130
- final long frequency = original .frequency ;
131
- spellchecker .setThresholdFrequency (this .suggestMode == SuggestMode .SUGGEST_ALWAYS ? 0 : thresholdFrequency (frequency , dictSize ));
135
+ if (suggestMode != SuggestMode .SUGGEST_ALWAYS ) {
136
+ /**
137
+ * We use the {@link TermStats#docFreq} to compute the frequency threshold
138
+ * because that's what {@link DirectSpellChecker#suggestSimilar} expects
139
+ * when filtering terms.
140
+ */
141
+ int threshold = thresholdTermFrequency (original .termStats .docFreq );
142
+ if (threshold == Integer .MAX_VALUE ) {
143
+ // the threshold is the max possible frequency so we can skip the search
144
+ return set ;
145
+ }
146
+ spellchecker .setThresholdFrequency (threshold );
147
+ }
148
+
132
149
SuggestWord [] suggestSimilar = spellchecker .suggestSimilar (new Term (field , term ), numCandidates , reader , this .suggestMode );
133
150
List <Candidate > candidates = new ArrayList <>(suggestSimilar .length );
134
151
for (int i = 0 ; i < suggestSimilar .length ; i ++) {
135
152
SuggestWord suggestWord = suggestSimilar [i ];
136
153
BytesRef candidate = new BytesRef (suggestWord .string );
137
- postFilter (new Candidate (candidate , internalFrequency (candidate ), suggestWord .score ,
138
- score (suggestWord .freq , suggestWord .score , dictSize ), false ), spare , byteSpare , candidates );
154
+ TermStats termStats = internalTermStats (candidate );
155
+ postFilter (new Candidate (candidate , termStats ,
156
+ suggestWord .score , score (termStats , suggestWord .score , sumTotalTermFreq ), false ), spare , byteSpare , candidates );
139
157
}
140
158
set .addCandidates (candidates );
141
159
return set ;
@@ -171,28 +189,30 @@ public void nextToken() throws IOException {
171
189
BytesRef term = result .toBytesRef ();
172
190
// We should not use termStats(term) here because it will analyze the term again
173
191
// If preFilter and postFilter are the same analyzer it would fail.
174
- long freq = internalFrequency (term );
175
- candidates .add (new Candidate (result .toBytesRef (), freq , candidate .stringDistance ,
176
- score (candidate .frequency , candidate .stringDistance , dictSize ), false ));
192
+ TermStats termStats = internalTermStats (term );
193
+ candidates .add (new Candidate (result .toBytesRef (), termStats , candidate .stringDistance ,
194
+ score (candidate .termStats , candidate .stringDistance , sumTotalTermFreq ), false ));
177
195
} else {
178
- candidates .add (new Candidate (result .toBytesRef (), candidate .frequency , nonErrorLikelihood ,
179
- score (candidate .frequency , candidate .stringDistance , dictSize ), false ));
196
+ candidates .add (new Candidate (result .toBytesRef (), candidate .termStats , nonErrorLikelihood ,
197
+ score (candidate .termStats , candidate .stringDistance , sumTotalTermFreq ), false ));
180
198
}
181
199
}
182
200
}, spare );
183
201
}
184
202
}
185
203
186
- private double score (long frequency , double errorScore , long dictionarySize ) {
187
- return errorScore * (((double )frequency + 1 ) / ((double )dictionarySize +1 ));
204
+ private double score (TermStats termStats , double errorScore , long dictionarySize ) {
205
+ return errorScore * (((double )termStats . totalTermFreq + 1 ) / ((double )dictionarySize +1 ));
188
206
}
189
207
190
- protected long thresholdFrequency (long termFrequency , long dictionarySize ) {
191
- if (termFrequency > 0 ) {
192
- return max (0 , round (termFrequency * (log10 (termFrequency - frequencyPlateau ) * (1.0 / log10 (LOG_BASE ))) + 1 ));
208
+ // package protected for test
209
+ int thresholdTermFrequency (int docFreq ) {
210
+ if (docFreq > 0 ) {
211
+ return (int ) min (
212
+ max (0 , round (docFreq * (log10 (docFreq - frequencyPlateau ) * (1.0 / log10 (LOG_BASE ))) + 1 )), Integer .MAX_VALUE
213
+ );
193
214
}
194
215
return 0 ;
195
-
196
216
}
197
217
198
218
public abstract static class TokenConsumer {
@@ -249,12 +269,12 @@ public static class Candidate implements Comparable<Candidate> {
249
269
public static final Candidate [] EMPTY = new Candidate [0 ];
250
270
public final BytesRef term ;
251
271
public final double stringDistance ;
252
- public final long frequency ;
272
+ public final TermStats termStats ;
253
273
public final double score ;
254
274
public final boolean userInput ;
255
275
256
- public Candidate (BytesRef term , long frequency , double stringDistance , double score , boolean userInput ) {
257
- this .frequency = frequency ;
276
+ public Candidate (BytesRef term , TermStats termStats , double stringDistance , double score , boolean userInput ) {
277
+ this .termStats = termStats ;
258
278
this .term = term ;
259
279
this .stringDistance = stringDistance ;
260
280
this .score = score ;
@@ -266,7 +286,7 @@ public String toString() {
266
286
return "Candidate [term=" + term .utf8ToString ()
267
287
+ ", stringDistance=" + stringDistance
268
288
+ ", score=" + score
269
- + ", frequency =" + frequency
289
+ + ", termStats =" + termStats
270
290
+ (userInput ? ", userInput" : "" ) + "]" ;
271
291
}
272
292
@@ -305,8 +325,8 @@ public int compareTo(Candidate other) {
305
325
}
306
326
307
327
@ Override
308
- public Candidate createCandidate (BytesRef term , long frequency , double channelScore , boolean userInput ) throws IOException {
309
- return new Candidate (term , frequency , channelScore , score (frequency , channelScore , dictSize ), userInput );
328
+ public Candidate createCandidate (BytesRef term , TermStats termStats , double channelScore , boolean userInput ) throws IOException {
329
+ return new Candidate (term , termStats , channelScore , score (termStats , channelScore , sumTotalTermFreq ), userInput );
310
330
}
311
331
312
332
public static int analyze (Analyzer analyzer , BytesRef toAnalyze , String field , TokenConsumer consumer , CharsRefBuilder spare )
0 commit comments