Skip to content

Commit ea62b54

Browse files
committed
add minhash to bucketFactory
Revert StandardBucket and add MinhashBucket; switch between 'standard' and 'minhash' via 'dynarank.script_sort.params.bucket_factory'.
1 parent fc22707 commit ea62b54

File tree

8 files changed

+472
-39
lines changed

8 files changed

+472
-39
lines changed

src/main/java/org/codelibs/elasticsearch/dynarank/script/DiversitySortScriptEngine.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.logging.log4j.Logger;
1414
import org.codelibs.elasticsearch.dynarank.script.bucket.BucketFactory;
1515
import org.codelibs.elasticsearch.dynarank.script.bucket.Buckets;
16+
import org.codelibs.elasticsearch.dynarank.script.bucket.impl.MinhashBucketFactory;
1617
import org.codelibs.elasticsearch.dynarank.script.bucket.impl.StandardBucketFactory;
1718
import org.elasticsearch.ElasticsearchException;
1819
import org.elasticsearch.common.settings.Setting;
@@ -29,6 +30,8 @@ public class DiversitySortScriptEngine implements ScriptEngine {
2930

3031
private static final String STANDARD = "standard";
3132

33+
private static final String MINHASH = "minhash";
34+
3235
public static final Setting<Settings> SETTING_SCRIPT_DYNARANK_BUCKET =
3336
Setting.groupSetting("script.dynarank.bucket.", Property.NodeScope);
3437

@@ -40,6 +43,7 @@ public DiversitySortScriptEngine(final Settings settings) {
4043

4144
bucketFactories = new HashMap<>();
4245
bucketFactories.put(STANDARD, new StandardBucketFactory(settings));
46+
bucketFactories.put(MINHASH, new MinhashBucketFactory(settings));
4347

4448
for (final String name : bucketSettings.names()) {
4549
try {
@@ -91,9 +95,9 @@ public SearchHit[] execute(SearchHit[] searchHit) {
9195
if (logger.isDebugEnabled()) {
9296
logger.debug("Starting DiversitySortScript...");
9397
}
94-
Object bucketFactoryName = params.get("bucket_factory");
95-
if (bucketFactoryName == null) {
96-
bucketFactoryName = STANDARD;
98+
Object bucketFactoryName = STANDARD;
99+
if (params.get("bucket_factory") != null) {
100+
bucketFactoryName = ((String[]) params.get("bucket_factory"))[0];
97101
}
98102
final BucketFactory bucketFactory = bucketFactories.get(bucketFactoryName);
99103
if (bucketFactory == null) {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package org.codelibs.elasticsearch.dynarank.script.bucket.impl;
2+
3+
import org.codelibs.elasticsearch.dynarank.script.bucket.Bucket;
4+
import org.codelibs.minhash.MinHash;
5+
import org.elasticsearch.search.SearchHit;
6+
7+
import java.util.LinkedList;
8+
import java.util.Queue;
9+
10+
public class MinhashBucket implements Bucket {
11+
protected Queue<SearchHit> queue = new LinkedList<>();
12+
13+
protected Object hash;
14+
15+
private final float threshold;
16+
17+
private final boolean isMinhash;
18+
19+
public MinhashBucket(final SearchHit hit, final Object hash, final float threshold, final boolean isMinhash) {
20+
this.hash = hash;
21+
this.threshold = threshold;
22+
this.isMinhash = isMinhash;
23+
queue.add(hit);
24+
}
25+
26+
@Override
27+
public void consume() {
28+
queue.poll();
29+
}
30+
31+
@Override
32+
public SearchHit get() {
33+
return queue.peek();
34+
}
35+
36+
@Override
37+
public boolean contains(final Object value) {
38+
if (hash == null) {
39+
return value == null;
40+
}
41+
42+
if (value == null) {
43+
return false;
44+
}
45+
46+
if (!hash.getClass().equals(value.getClass())) {
47+
return false;
48+
}
49+
50+
if (value instanceof String) {
51+
if (isMinhash) {
52+
return MinHash.compare(hash.toString(), value.toString()) >= threshold;
53+
}
54+
return value.toString().equals(hash);
55+
} else if (value instanceof Number) {
56+
return Math.abs(((Number) value).doubleValue() - ((Number) hash).doubleValue()) < threshold;
57+
} else if (value instanceof byte[]) {
58+
final byte[] target = (byte[]) value;
59+
return MinHash.compare((byte[]) hash, target) >= threshold;
60+
}
61+
return false;
62+
}
63+
64+
@Override
65+
public void add(final Object... args) {
66+
queue.add((SearchHit) args[0]);
67+
}
68+
69+
@Override
70+
public int size() {
71+
return queue.size();
72+
}
73+
74+
@Override
75+
public String toString() {
76+
return "MinhashBucket [queue=" + queue + ", hash=" + hash + ", threshold=" + threshold + ", isMinhash=" + isMinhash + "]";
77+
}
78+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.codelibs.elasticsearch.dynarank.script.bucket.impl;
2+
3+
import org.codelibs.elasticsearch.dynarank.script.bucket.Bucket;
4+
import org.codelibs.elasticsearch.dynarank.script.bucket.BucketFactory;
5+
import org.codelibs.elasticsearch.dynarank.script.bucket.Buckets;
6+
import org.elasticsearch.common.settings.Settings;
7+
import org.elasticsearch.search.SearchHit;
8+
9+
import java.util.Map;
10+
11+
public class MinhashBucketFactory implements BucketFactory {
12+
13+
protected Settings settings;
14+
15+
public MinhashBucketFactory(final Settings settings) {
16+
this.settings = settings;
17+
}
18+
19+
@Override
20+
public Buckets createBucketList(final Map<String, Object> params) {
21+
return new MinhashBuckets(this, params);
22+
}
23+
24+
@Override
25+
public Bucket createBucket(final Object... args) {
26+
return new MinhashBucket((SearchHit) args[0], args[1], (float) args[2], (boolean) args[3]);
27+
}
28+
}
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
package org.codelibs.elasticsearch.dynarank.script.bucket.impl;
2+
3+
import org.apache.logging.log4j.LogManager;
4+
import org.apache.logging.log4j.Logger;
5+
import org.codelibs.elasticsearch.dynarank.ranker.RetrySearchException;
6+
import org.codelibs.elasticsearch.dynarank.script.bucket.Bucket;
7+
import org.codelibs.elasticsearch.dynarank.script.bucket.BucketFactory;
8+
import org.codelibs.elasticsearch.dynarank.script.bucket.Buckets;
9+
import org.elasticsearch.ElasticsearchException;
10+
import org.elasticsearch.common.bytes.BytesArray;
11+
import org.elasticsearch.common.bytes.BytesReference;
12+
import org.elasticsearch.common.document.DocumentField;
13+
import org.elasticsearch.common.lucene.search.function.CombineFunction;
14+
import org.elasticsearch.index.query.QueryBuilders;
15+
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
16+
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
17+
import org.elasticsearch.search.SearchHit;
18+
import org.elasticsearch.search.builder.SearchSourceBuilder;
19+
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
22+
import java.util.List;
23+
import java.util.Map;
24+
25+
public class MinhashBuckets implements Buckets {
26+
27+
private static final Logger logger = LogManager.getLogger(MinhashBuckets.class);
28+
29+
protected BucketFactory bucketFactory;
30+
31+
protected Map<String, Object> params;
32+
33+
public MinhashBuckets(final BucketFactory bucketFactory, final Map<String, Object> params) {
34+
this.bucketFactory = bucketFactory;
35+
this.params = params;
36+
}
37+
38+
@Override
39+
public SearchHit[] getHits(final SearchHit[] searchHits) {
40+
SearchHit[] hits = searchHits;
41+
final int length = hits.length;
42+
final String[] diversityFields = (String[]) params.get("diversity_fields");
43+
if (diversityFields == null) {
44+
throw new ElasticsearchException("diversity_fields is null.");
45+
}
46+
final String[] thresholds = (String[]) params.get("diversity_thresholds");
47+
if (thresholds == null) {
48+
throw new ElasticsearchException("diversity_thresholds is null.");
49+
}
50+
final Object sourceAsMap = params.get("source_as_map");
51+
final float[] diversityThresholds = parseFloats(thresholds);
52+
final Object[][] ignoredObjGroups = new Object[diversityFields.length][];
53+
final String[] minhashFields = new String[diversityFields.length];
54+
for (int i = 0; i < diversityFields.length; i++) {
55+
ignoredObjGroups[i] = (String[]) params.get(diversityFields[i] + "_ignored_objects");
56+
if (isMinhashFields(sourceAsMap, diversityFields[i])) {
57+
minhashFields[i] = diversityFields[i];
58+
}
59+
}
60+
61+
if (logger.isDebugEnabled()) {
62+
logger.debug("diversity_fields: {}, : diversity_thresholds{}", diversityFields, thresholds);
63+
}
64+
int maxNumOfBuckets = 0;
65+
int minNumOfBuckets = Integer.MAX_VALUE;
66+
for (int i = diversityFields.length - 1; i >= 0; i--) {
67+
final String diversityField = diversityFields[i];
68+
final boolean isMinhash = Arrays.asList(minhashFields).contains(diversityField);
69+
final float diversityThreshold = diversityThresholds[i];
70+
final Object[] ignoredObjs = ignoredObjGroups[i];
71+
final List<Bucket> bucketList = new ArrayList<>();
72+
for (int j = 0; j < length; j++) {
73+
boolean insert = false;
74+
final SearchHit hit = hits[j];
75+
final Object value = getFieldValue(hit, diversityField);
76+
if (value == null) {
77+
if (logger.isDebugEnabled()) {
78+
logger.debug("diversityField {} does not exist. Reranking is skipped.", diversityField);
79+
}
80+
return hits;
81+
}
82+
if (ignoredObjs != null) {
83+
for (final Object ignoredObj : ignoredObjs) {
84+
if (ignoredObj.equals(value)) {
85+
bucketList.add(bucketFactory.createBucket(hit, value, diversityThreshold, isMinhash));
86+
insert = true;
87+
break;
88+
}
89+
}
90+
}
91+
if (!insert) {
92+
for (final Bucket bucket : bucketList) {
93+
if (bucket.contains(value)) {
94+
bucket.add(hit, value);
95+
insert = true;
96+
break;
97+
}
98+
}
99+
if (!insert) {
100+
bucketList.add(bucketFactory.createBucket(hit, value, diversityThreshold, isMinhash));
101+
}
102+
}
103+
}
104+
if (bucketList.size() > maxNumOfBuckets) {
105+
maxNumOfBuckets = bucketList.size();
106+
}
107+
if (bucketList.size() < minNumOfBuckets) {
108+
minNumOfBuckets = bucketList.size();
109+
}
110+
hits = createHits(length, bucketList);
111+
}
112+
113+
int minBucketThreshold = 0;
114+
int maxBucketThreshold = 0;
115+
116+
final Object minBucketThresholdStr = params.get("min_bucket_threshold");
117+
if (minBucketThresholdStr instanceof String) {
118+
try {
119+
minBucketThreshold = Integer.parseInt(minBucketThresholdStr.toString());
120+
} catch (final NumberFormatException e) {
121+
throw new ElasticsearchException("Invalid value of min_bucket_threshold: " + minBucketThresholdStr.toString(), e);
122+
}
123+
} else if (minBucketThresholdStr instanceof Number) {
124+
minBucketThreshold = ((Number) minBucketThresholdStr).intValue();
125+
}
126+
127+
final Object maxBucketThresholdStr = params.get("max_bucket_threshold");
128+
if (maxBucketThresholdStr instanceof String) {
129+
try {
130+
maxBucketThreshold = Integer.parseInt(maxBucketThresholdStr.toString());
131+
} catch (final NumberFormatException e) {
132+
throw new ElasticsearchException("Invalid value of max_bucket_threshold: " + maxBucketThresholdStr.toString(), e);
133+
}
134+
} else if (maxBucketThresholdStr instanceof Number) {
135+
maxBucketThreshold = ((Number) maxBucketThresholdStr).intValue();
136+
}
137+
138+
if (logger.isDebugEnabled()) {
139+
logger.debug("searchHits: {}, minNumOfBuckets: {}, maxNumOfBuckets: {}, minBucketSize: {}, maxBucketThreshold: {}",
140+
hits.length, minNumOfBuckets, maxNumOfBuckets, minBucketThreshold, maxBucketThreshold);
141+
}
142+
143+
if ((minBucketThreshold > 0 && minBucketThreshold >= minNumOfBuckets)
144+
|| (maxBucketThreshold > 0 && maxBucketThreshold >= maxNumOfBuckets)) {
145+
final Object shuffleSeed = params.get("shuffle_seed");
146+
if (shuffleSeed != null) {
147+
if (logger.isDebugEnabled()) {
148+
logger.debug("minBucketSize: {}", shuffleSeed);
149+
}
150+
throw new RetrySearchException(new RetrySearchException.QueryRewriter() {
151+
private static final long serialVersionUID = 1L;
152+
153+
@Override
154+
public SearchSourceBuilder rewrite(final SearchSourceBuilder source) {
155+
float shuffleWeight = 1;
156+
if (params.get("shuffle_weight") instanceof Number) {
157+
shuffleWeight = ((Number) params.get("shuffle_weight")).floatValue();
158+
}
159+
final Object shuffleBoostMode = params.get("shuffle_boost_mode");
160+
161+
final FunctionScoreQueryBuilder functionScoreQuery = QueryBuilders.functionScoreQuery(source.query(),
162+
new FunctionScoreQueryBuilder.FilterFunctionBuilder[] { new FunctionScoreQueryBuilder.FilterFunctionBuilder(
163+
ScoreFunctionBuilders.randomFunction().seed(shuffleSeed.toString()).setWeight(shuffleWeight)) });
164+
if (shuffleBoostMode != null) {
165+
functionScoreQuery.boostMode(CombineFunction.fromString(shuffleBoostMode.toString()));
166+
}
167+
source.query(functionScoreQuery);
168+
return source;
169+
}
170+
});
171+
}
172+
}
173+
174+
return hits;
175+
}
176+
177+
private Object getFieldValue(final SearchHit hit, final String fieldName) {
178+
final DocumentField field = hit.getFields().get(fieldName);
179+
if (field == null) {
180+
final Map<String, Object> source = hit.getSourceAsMap();
181+
// TODO nested
182+
final Object object = source.get(fieldName);
183+
if (object instanceof String) {
184+
return object;
185+
} else if (object instanceof Number) {
186+
return object;
187+
}
188+
return null;
189+
}
190+
final Object object = field.getValue();
191+
if (object instanceof BytesReference) {
192+
return BytesReference.toBytes((BytesReference) object);
193+
} else if (object instanceof String) {
194+
return object;
195+
} else if (object instanceof Number) {
196+
return object;
197+
} else if (object instanceof BytesArray) {
198+
return ((BytesArray) object).array();
199+
}
200+
return null;
201+
}
202+
203+
private float[] parseFloats(final String[] strings) {
204+
final float[] values = new float[strings.length];
205+
for (int i = 0; i < strings.length; i++) {
206+
values[i] = Float.parseFloat(strings[i]);
207+
}
208+
return values;
209+
}
210+
211+
protected SearchHit[] createHits(final int size, final List<Bucket> bucketList) {
212+
if (logger.isDebugEnabled()) {
213+
logger.debug("{} docs -> {} buckets", size, bucketList.size());
214+
for (int i = 0; i < bucketList.size(); i++) {
215+
final Bucket bucket = bucketList.get(i);
216+
logger.debug(" bucket[{}] -> {} docs", i, bucket.size());
217+
}
218+
}
219+
220+
int pos = 0;
221+
final SearchHit[] newSearchHits = new SearchHit[size];
222+
while (pos < size) {
223+
for (final Bucket bucket : bucketList) {
224+
final SearchHit hit = bucket.get();
225+
if (hit != null) {
226+
newSearchHits[pos] = hit;
227+
pos++;
228+
bucket.consume();
229+
}
230+
}
231+
}
232+
233+
return newSearchHits;
234+
}
235+
236+
@SuppressWarnings("unchecked")
237+
private boolean isMinhashFields(Object sourceAsMap, String field) {
238+
if (sourceAsMap instanceof Map) {
239+
Object propertiesMap = ((Map<String, Object>) sourceAsMap).get("properties");
240+
if (propertiesMap instanceof Map) {
241+
Object fieldMap = ((Map<String, Object>) propertiesMap).get(field);
242+
if (fieldMap instanceof Map) {
243+
Object fieldType = ((Map<String, Object>) fieldMap).get("type");
244+
return fieldType != null && fieldType.toString().equals("minhash");
245+
}
246+
}
247+
}
248+
return false;
249+
}
250+
}

0 commit comments

Comments
 (0)