Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ANNSearcher class for easier ANN usage from code #1078

Merged
merged 7 commits into from
Apr 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
changes to reflect JL comments
  • Loading branch information
tteofili committed Apr 4, 2020
commit ed3eb85238b1e2565bf5068723dd2b8ed18daad6
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public static void main(String[] args) throws Exception {
Set<String> observations = new HashSet<>();
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String wordValue = document.get(IndexVectors.FIELD_WORD);
String wordValue = document.get(IndexVectors.FIELD_ID);
observations.add(wordValue);
}
double intersection = Sets.intersection(truth, observations).size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
Expand Down Expand Up @@ -143,7 +142,7 @@ public static void main(String[] args) throws Exception {

Collection<String> vectors = new LinkedList<>();
if (indexArgs.stored) {
TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_WORD, indexArgs.word)), indexArgs.depth);
TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
vectors.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
}
Expand Down Expand Up @@ -188,7 +187,7 @@ public static void main(String[] args) throws Exception {
int rank = 1;
for (ScoreDoc sd : results.topDocs().scoreDocs) {
Document document = reader.document(sd.doc);
String word = document.get(IndexVectors.FIELD_WORD);
String word = document.get(IndexVectors.FIELD_ID);
System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score));
rank++;
}
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/io/anserini/ann/IndexVectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
import java.util.concurrent.atomic.AtomicInteger;

public class IndexVectors {
public static final String FIELD_WORD = "word";
public static final String FIELD_ID = "id";
public static final String FIELD_VECTOR = "vector";

public static final String FW = "fw";
Expand All @@ -63,7 +63,7 @@ public static final class Args {
public Path path;

@Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}")
public String encoding;
public String encoding = FW;

@Option(name="-stored", metaVar = "[boolean]", usage = "store vectors")
public boolean stored;
Expand Down Expand Up @@ -116,7 +116,7 @@ public static void main(String[] args) throws Exception {
final long start = System.nanoTime();
System.out.println(String.format("Loading model %s", indexArgs.input));

Map<String, float[]> wordVectors = readGloVe(indexArgs.input);
Map<String, float[]> vectors = readGloVe(indexArgs.input);

Path indexDir = indexArgs.path;
if (!Files.exists(indexDir)) {
Expand All @@ -134,10 +134,10 @@ public static void main(String[] args) throws Exception {
IndexWriter indexWriter = new IndexWriter(d, conf);
final AtomicInteger cnt = new AtomicInteger();

for (Map.Entry<String, float[]> entry : wordVectors.entrySet()) {
for (Map.Entry<String, float[]> entry : vectors.entrySet()) {
Document doc = new Document();

doc.add(new StringField(FIELD_WORD, entry.getKey(), Field.Store.YES));
doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
float[] vector = entry.getValue();
StringBuilder sb = new StringBuilder();
for (double fv : vector) {
Expand All @@ -151,15 +151,15 @@ public static void main(String[] args) throws Exception {
indexWriter.addDocument(doc);
int cur = cnt.incrementAndGet();
if (cur % 100000 == 0) {
System.out.println(String.format("%s words added", cnt));
System.out.println(String.format("%s docs added", cnt));
}
} catch (IOException e) {
System.err.println("Error while indexing: " + e.getLocalizedMessage());
}
}

indexWriter.commit();
System.out.println(String.format("%s words indexed", cnt.get()));
System.out.println(String.format("%s docs indexed", cnt.get()));
long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L);
System.out.println(String.format("Index size: %dMB", space));
indexWriter.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@

import static org.apache.lucene.search.BooleanClause.Occur.SHOULD;

public class ANNSearcher {
public class SimpleNearestNeighborSearcher {

private final Analyzer analyzer;
private final IndexSearcher searcher;

public ANNSearcher(String path) throws IOException {
public SimpleNearestNeighborSearcher(String path) throws IOException {
this(path, IndexVectors.FW);
}

public ANNSearcher(String path, String encoding) throws IOException {
public SimpleNearestNeighborSearcher(String path, String encoding) throws IOException {
Directory d = FSDirectory.open(Paths.get(path));
DirectoryReader reader = DirectoryReader.open(d);
searcher = new IndexSearcher(reader);
Expand All @@ -63,9 +63,9 @@ public ANNSearcher(String path, String encoding) throws IOException {
}
}

public NearestNeighbors[] annSearch(String word, int k) throws IOException {
List<NearestNeighbors> results = new ArrayList<>();
TopDocs wordDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_WORD, word)), k);
public Result[][] search(String word, int k) throws IOException {
List<Result[]> results = new ArrayList<>();
TopDocs wordDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, word)), k);

for (ScoreDoc scoreDoc : wordDocs.scoreDocs) {
Document doc = searcher.doc(scoreDoc.doc);
Expand All @@ -76,53 +76,26 @@ public NearestNeighbors[] annSearch(String word, int k) throws IOException {
simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
}
TopDocs nearest = searcher.search(simQuery, k);
NearestNeighbors.Neighbor[] neighbors = new NearestNeighbors.Neighbor[nearest.scoreDocs.length];
Result[] neighbors = new Result[nearest.scoreDocs.length];
int i = 0;
for (ScoreDoc nn : nearest.scoreDocs) {
Document ndoc = searcher.doc(nn.doc);
neighbors[i] = new NearestNeighbors.Neighbor(ndoc.get(IndexVectors.FIELD_WORD), nn.score);
neighbors[i] = new Result(ndoc.get(IndexVectors.FIELD_ID), nn.score);
i++;
}
results.add(new NearestNeighbors(doc.get(IndexVectors.FIELD_WORD), neighbors));
results.add(neighbors);
}
return results.toArray(new NearestNeighbors[0]);
return results.toArray(new Result[0][0]);
}

public static class NearestNeighbors {
public static class Result {

public final String id;
public final Neighbor[] neighbors;
public final float score;

private NearestNeighbors(String id, Neighbor[] neighbors) {
private Result(String id, float score) {
this.id = id;
this.neighbors = neighbors;
}

public static class Neighbor {

public final String id;
public final float score;

private Neighbor(String id, float score) {
this.id = id;
this.score = score;
}

@Override
public String toString() {
return "Neighbor{" +
"id='" + id + '\'' +
", score=" + score +
'}';
}
}

@Override
public String toString() {
return "NearestNeighbors{" +
"id='" + id + '\'' +
", neighbors=" + Arrays.toString(neighbors) +
'}';
this.score = score;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,27 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class ANNSearcherTest {
public class SimpleNearestNeighborSearcherTest {

@Test
public void testSearchingFW() throws Exception {
String idxPath = "target/ast" + System.currentTimeMillis();
IndexVectorsTest.createIndex(idxPath, "fw", true);
ANNSearcher annSearcher = new ANNSearcher(idxPath);
ANNSearcher.NearestNeighbors[] results = annSearcher.annSearch("text", 2);
SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath);
SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.search("text", 2);
assertNotNull(results);
assertEquals(1, results.length);
assertEquals(2, results[0].neighbors.length);
assertEquals(2, results[0].length);
}

@Test
public void testSearchingLL() throws Exception {
String idxPath = "target/ast" + System.currentTimeMillis();
IndexVectorsTest.createIndex(idxPath, "lexlsh", true);
ANNSearcher annSearcher = new ANNSearcher(idxPath, "lexlsh");
ANNSearcher.NearestNeighbors[] results = annSearcher.annSearch("text", 2);
SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath, "lexlsh");
SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.search("text", 2);
assertNotNull(results);
assertEquals(1, results.length);
assertEquals(1, results[0].neighbors.length);
assertEquals(1, results[0].length);
}
}