Skip to content

Commit

Permalink
Add SimpleNearestNeighborSearcher to expose in Python (castorini#1078)
Browse files Browse the repository at this point in the history
  • Loading branch information
tteofili authored Apr 4, 2020
1 parent fc42ce2 commit 6f4b9bf
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public static void main(String[] args) throws Exception {
Set<String> observations = new HashSet<>();
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String wordValue = document.get(IndexVectors.FIELD_WORD);
String wordValue = document.get(IndexVectors.FIELD_ID);
observations.add(wordValue);
}
double intersection = Sets.intersection(truth, observations).size();
Expand Down
90 changes: 59 additions & 31 deletions src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.Directory;
Expand All @@ -39,6 +41,8 @@
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;

import static org.apache.lucene.search.BooleanClause.Occur.SHOULD;
Expand All @@ -48,7 +52,7 @@ public class ApproximateNearestNeighborSearch {
private static final String LEXLSH = "lexlsh";

public static final class Args {
@Option(name = "-input", metaVar = "[file]", required = true, usage = "word vectors model")
@Option(name = "-input", metaVar = "[file]", usage = "vectors model")
public File input;

@Option(name = "-path", metaVar = "[path]", required = true, usage = "index path")
Expand All @@ -57,6 +61,9 @@ public static final class Args {
@Option(name = "-word", metaVar = "[word]", required = true, usage = "input word")
public String word;

@Option(name="-stored", metaVar = "[boolean]", usage = "fetch stored vectors from index")
public boolean stored;

@Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}")
public String encoding;

Expand Down Expand Up @@ -114,9 +121,10 @@ public static void main(String[] args) throws Exception {
return;
}

System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = IndexVectors.readGloVe(indexArgs.input);
if (!indexArgs.stored && indexArgs.input == null) {
System.err.println("Either -path or -stored args must be set");
return;
}

Path indexDir = indexArgs.path;
if (!Files.exists(indexDir)) {
Expand All @@ -132,39 +140,59 @@ public static void main(String[] args) throws Exception {
searcher.setSimilarity(new ClassicSimilarity());
}

float[] vector = wordVectors.get(indexArgs.word);
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
Collection<String> vectors = new LinkedList<>();
if (indexArgs.stored) {
TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
vectors.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
}
} else {
System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = IndexVectors.readGloVe(indexArgs.input);

if (wordVectors.containsKey(indexArgs.word)) {
float[] vector = wordVectors.get(indexArgs.word);
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
if (sb.length() > 0) {
sb.append(' ');
}
sb.append(fv);
}
String vectorString = sb.toString();
vectors.add(vectorString);
}
sb.append(fv);
}
CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
for (String token : AnalyzerUtils.analyze(vectorAnalyzer, sb.toString())) {
simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
}
if (indexArgs.msm > 0) {
simQuery.setHighFreqMinimumNumberShouldMatch(indexArgs.msm);
simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
}

long start = System.currentTimeMillis();
TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
searcher.search(simQuery, results);
long time = System.currentTimeMillis() - start;
for (String vectorString : vectors) {
float msm = indexArgs.msm;
float cutoff = indexArgs.cutoff;
CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff);
for (String token : AnalyzerUtils.analyze(vectorAnalyzer, vectorString)) {
simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
}
if (msm > 0) {
simQuery.setHighFreqMinimumNumberShouldMatch(msm);
simQuery.setLowFreqMinimumNumberShouldMatch(msm);
}

long start = System.currentTimeMillis();
TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
searcher.search(simQuery, results);
long time = System.currentTimeMillis() - start;

System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word));
System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word));

int rank = 1;
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String word = document.get(IndexVectors.FIELD_WORD);
System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score));
rank++;
int rank = 1;
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String word = document.get(IndexVectors.FIELD_ID);
System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score));
rank++;
}
System.out.println(String.format("Search time: %dms", time));
}
System.out.println(String.format("Search time: %dms", time));

reader.close();
d.close();
}
Expand Down
29 changes: 16 additions & 13 deletions src/main/java/io/anserini/ann/IndexVectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,24 @@
import java.util.concurrent.atomic.AtomicInteger;

public class IndexVectors {
static final String FIELD_WORD = "word";
static final String FIELD_VECTOR = "vector";
public static final String FIELD_ID = "id";
public static final String FIELD_VECTOR = "vector";

private static final String FW = "fw";
private static final String LEXLSH = "lexlsh";
public static final String FW = "fw";
public static final String LEXLSH = "lexlsh";

public static final class Args {
@Option(name = "-input", metaVar = "[file]", required = true, usage = "word vectors model")
@Option(name = "-input", metaVar = "[file]", required = true, usage = "vectors model")
public File input;

@Option(name = "-path", metaVar = "[path]", required = true, usage = "index path")
public Path path;

@Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}")
public String encoding;
public String encoding = FW;

@Option(name="-stored", metaVar = "[boolean]", usage = "store vectors")
public boolean stored;

@Option(name = "-lexlsh.n", metaVar = "[int]", usage = "ngrams")
public int ngrams = 2;
Expand All @@ -81,7 +84,7 @@ public static final class Args {
public int bucketCount = 300;

@Option(name = "-fw.q", metaVar = "[int]", usage = "quantization factor")
public int q = 60;
public int q = FakeWordsEncoderAnalyzer.DEFAULT_Q;
}

public static void main(String[] args) throws Exception {
Expand Down Expand Up @@ -113,7 +116,7 @@ public static void main(String[] args) throws Exception {
final long start = System.nanoTime();
System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = readGloVe(indexArgs.input);
Map<String, float[]> vectors = readGloVe(indexArgs.input);

Path indexDir = indexArgs.path;
if (!Files.exists(indexDir)) {
Expand All @@ -131,10 +134,10 @@ public static void main(String[] args) throws Exception {
IndexWriter indexWriter = new IndexWriter(d, conf);
final AtomicInteger cnt = new AtomicInteger();

for (Map.Entry<String, float[]> entry : wordVectors.entrySet()) {
for (Map.Entry<String, float[]> entry : vectors.entrySet()) {
Document doc = new Document();

doc.add(new StringField(FIELD_WORD, entry.getKey(), Field.Store.YES));
doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
float[] vector = entry.getValue();
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
Expand All @@ -143,20 +146,20 @@ public static void main(String[] args) throws Exception {
}
sb.append(fv);
}
doc.add(new TextField(FIELD_VECTOR, sb.toString(), Field.Store.NO));
doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO));
try {
indexWriter.addDocument(doc);
int cur = cnt.incrementAndGet();
if (cur % 100000 == 0) {
System.out.println(String.format("%s words added", cnt));
System.out.println(String.format("%s docs added", cnt));
}
} catch (IOException e) {
System.err.println("Error while indexing: " + e.getLocalizedMessage());
}
}

indexWriter.commit();
System.out.println(String.format("%s words indexed", cnt.get()));
System.out.println(String.format("%s docs indexed", cnt.get()));
long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L);
System.out.println(String.format("Index size: %dMB", space));
indexWriter.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class FakeWordsEncoderAnalyzer extends Analyzer {

static final String REMOVE_IT = "_";

private static final int DEFAULT_Q = 60;
public static final int DEFAULT_Q = 80;

private final int q;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
*/
public class LexicalLshAnalyzer extends Analyzer {

private static final int DEFAULT_SHINGLE_SIZE = 5;
private static final int DEFAULT_SHINGLE_SIZE = 2;
private static final int DEFAULT_DECIMALS = 1;

private final int min;
Expand Down
101 changes: 101 additions & 0 deletions src/main/java/io/anserini/search/SimpleNearestNeighborSearcher.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Anserini: A Lucene toolkit for replicable information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.anserini.search;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import io.anserini.analysis.AnalyzerUtils;
import io.anserini.ann.IndexVectors;
import io.anserini.ann.fw.FakeWordsEncoderAnalyzer;
import io.anserini.ann.lexlsh.LexicalLshAnalyzer;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

import static org.apache.lucene.search.BooleanClause.Occur.SHOULD;

/**
 * Performs approximate nearest-neighbor search over an index built by
 * {@code IndexVectors} with stored vectors. Given an id (e.g. a word), it
 * looks up the stored vector string and issues a {@link CommonTermsQuery}
 * over the encoded vector tokens to retrieve the k nearest neighbors.
 *
 * <p>Instances hold an open index reader; call {@link #close()} (or use
 * try-with-resources) when done to release index resources.
 */
public class SimpleNearestNeighborSearcher implements java.io.Closeable {

  private final Analyzer analyzer;
  private final IndexSearcher searcher;
  // Retained so the underlying index resources can be released in close();
  // previously these were opened in the constructor and leaked.
  private final DirectoryReader reader;
  private final Directory directory;

  /**
   * Opens a searcher using the default fake-words ({@code fw}) encoding.
   *
   * @param path path to the Lucene index directory
   * @throws IOException if the index cannot be opened
   */
  public SimpleNearestNeighborSearcher(String path) throws IOException {
    this(path, IndexVectors.FW);
  }

  /**
   * Opens a searcher with an explicit vector encoding.
   *
   * @param path     path to the Lucene index directory
   * @param encoding vector encoding used at index time; one of
   *                 {@code IndexVectors.FW} or {@code IndexVectors.LEXLSH}
   * @throws IOException              if the index cannot be opened
   * @throws IllegalArgumentException if the encoding is not recognized
   */
  public SimpleNearestNeighborSearcher(String path, String encoding) throws IOException {
    directory = FSDirectory.open(Paths.get(path));
    reader = DirectoryReader.open(directory);
    searcher = new IndexSearcher(reader);
    if (encoding.equalsIgnoreCase(IndexVectors.LEXLSH)) {
      analyzer = new LexicalLshAnalyzer();
    } else if (encoding.equalsIgnoreCase(IndexVectors.FW)) {
      analyzer = new FakeWordsEncoderAnalyzer();
      // fw encoding relies on raw term-frequency scoring, hence ClassicSimilarity.
      searcher.setSimilarity(new ClassicSimilarity());
    } else {
      // IllegalArgumentException (a RuntimeException subclass) is the
      // conventional type for a bad argument; existing catchers still work.
      throw new IllegalArgumentException("unexpected encoding " + encoding);
    }
  }

  /**
   * Finds the k nearest neighbors for each indexed document whose id matches
   * {@code word}.
   *
   * @param word id of the vector(s) to search neighbors for
   * @param k    maximum number of neighbors to return per matching document
   * @return one {@code Result[]} (ranked by score, descending) per document
   *         whose id matched {@code word}; empty outer array if none matched.
   *         Note the query document itself typically appears among its own
   *         neighbors.
   * @throws IOException if index access fails
   */
  public Result[][] search(String word, int k) throws IOException {
    List<Result[]> results = new ArrayList<>();
    // First resolve the id to its stored vector representation(s).
    TopDocs wordDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, word)), k);

    for (ScoreDoc scoreDoc : wordDocs.scoreDocs) {
      Document doc = searcher.doc(scoreDoc.doc);
      String vector = doc.get(IndexVectors.FIELD_VECTOR);
      // Re-analyze the stored vector into encoded tokens and match documents
      // sharing those tokens; cutoff 0 treats all terms uniformly.
      CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, 0);
      List<String> tokens = AnalyzerUtils.analyze(analyzer, vector);
      for (String token : tokens) {
        simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
      }
      TopDocs nearest = searcher.search(simQuery, k);
      Result[] neighbors = new Result[nearest.scoreDocs.length];
      int i = 0;
      for (ScoreDoc nn : nearest.scoreDocs) {
        Document ndoc = searcher.doc(nn.doc);
        neighbors[i] = new Result(ndoc.get(IndexVectors.FIELD_ID), nn.score);
        i++;
      }
      results.add(neighbors);
    }
    return results.toArray(new Result[0][0]);
  }

  /**
   * Releases the index reader and directory held by this searcher.
   *
   * @throws IOException if closing the underlying index resources fails
   */
  @Override
  public void close() throws IOException {
    reader.close();
    directory.close();
  }

  /**
   * A single nearest-neighbor hit: the neighbor's id and its query score.
   */
  public static class Result {

    /** Id of the neighbor document (the {@code IndexVectors.FIELD_ID} value). */
    public final String id;
    /** Lucene similarity score of the neighbor against the query vector. */
    public final float score;

    private Result(String id, float score) {
      this.id = id;
      this.score = score;
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class ApproximateNearestNeighborEvalTest {
public void evalFWTest() throws Exception {
String path = "target/idx-sample-fw";
String encoding = "fw";
IndexVectorsTest.createIndex(path, encoding);
IndexVectorsTest.createIndex(path, encoding, false);
String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path",
path, "-topics", "src/test/resources/sample_topics/Trec"};
ApproximateNearestNeighborEval.main(args);
Expand All @@ -37,7 +37,7 @@ public void evalFWTest() throws Exception {
public void evalLLTest() throws Exception {
String path = "target/idx-sample-ll";
String encoding = "lexlsh";
IndexVectorsTest.createIndex(path, encoding);
IndexVectorsTest.createIndex(path, encoding, false);
String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path",
path, "-topics", "src/test/resources/sample_topics/Trec"};
ApproximateNearestNeighborEval.main(args);
Expand Down
Loading

0 comments on commit 6f4b9bf

Please sign in to comment.