diff --git a/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java b/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java index ae08ab8eba..ccc72ff00a 100644 --- a/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java +++ b/src/main/java/io/anserini/ann/ApproximateNearestNeighborEval.java @@ -185,7 +185,7 @@ public static void main(String[] args) throws Exception { Set observations = new HashSet<>(); for (ScoreDoc sd : results.topDocs().scoreDocs) { Document document = reader.document(sd.doc); - String wordValue = document.get(IndexVectors.FIELD_WORD); + String wordValue = document.get(IndexVectors.FIELD_ID); observations.add(wordValue); } double intersection = Sets.intersection(truth, observations).size(); diff --git a/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java b/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java index 8cb82413dc..e41377967e 100644 --- a/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java +++ b/src/main/java/io/anserini/ann/ApproximateNearestNeighborSearch.java @@ -26,6 +26,8 @@ import org.apache.lucene.queries.CommonTermsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.store.Directory; @@ -39,6 +41,8 @@ import java.io.File; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Collection; +import java.util.LinkedList; import java.util.Map; import static org.apache.lucene.search.BooleanClause.Occur.SHOULD; @@ -48,7 +52,7 @@ public class ApproximateNearestNeighborSearch { private static final String LEXLSH = "lexlsh"; public static final class Args { - @Option(name = "-input", metaVar = "[file]", required = true, usage = "word vectors model") + @Option(name = "-input", metaVar = "[file]", usage = "vectors model") public File input; @Option(name = "-path", metaVar = "[path]", required = true, usage = "index path") @@ -57,6 +61,9 @@ public static final class Args { @Option(name = "-word", metaVar = "[word]", required = true, usage = "input word") public String word; + @Option(name="-stored", metaVar = "[boolean]", usage = "fetch stored vectors from index") + public boolean stored; + @Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}") public String encoding; @@ -114,9 +121,10 @@ public static void main(String[] args) throws Exception { return; } - System.out.println(String.format("Loading model %s", indexArgs.input)); - - Map wordVectors = IndexVectors.readGloVe(indexArgs.input); + if (!indexArgs.stored && indexArgs.input == null) { + System.err.println("Either -path or -stored args must be set"); + return; + } Path indexDir = indexArgs.path; if (!Files.exists(indexDir)) { @@ -132,39 +140,59 @@ public static void main(String[] args) throws Exception { searcher.setSimilarity(new ClassicSimilarity()); } - float[] vector = wordVectors.get(indexArgs.word); - StringBuilder sb = new StringBuilder(); - for (double fv : vector) { - if (sb.length() > 0) { - sb.append(' '); + Collection vectors = new LinkedList<>(); + if (indexArgs.stored) { + TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth); + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + 
vectors.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR)); + } + } else { + System.out.println(String.format("Loading model %s", indexArgs.input)); + + Map wordVectors = IndexVectors.readGloVe(indexArgs.input); + + if (wordVectors.containsKey(indexArgs.word)) { + float[] vector = wordVectors.get(indexArgs.word); + StringBuilder sb = new StringBuilder(); + for (double fv : vector) { + if (sb.length() > 0) { + sb.append(' '); + } + sb.append(fv); + } + String vectorString = sb.toString(); + vectors.add(vectorString); } - sb.append(fv); - } - CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff); - for (String token : AnalyzerUtils.analyze(vectorAnalyzer, sb.toString())) { - simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token)); - } - if (indexArgs.msm > 0) { - simQuery.setHighFreqMinimumNumberShouldMatch(indexArgs.msm); - simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm); } - long start = System.currentTimeMillis(); - TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE); - searcher.search(simQuery, results); - long time = System.currentTimeMillis() - start; + for (String vectorString : vectors) { + float msm = indexArgs.msm; + float cutoff = indexArgs.cutoff; + CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff); + for (String token : AnalyzerUtils.analyze(vectorAnalyzer, vectorString)) { + simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token)); + } + if (msm > 0) { + simQuery.setHighFreqMinimumNumberShouldMatch(msm); + simQuery.setLowFreqMinimumNumberShouldMatch(msm); + } + + long start = System.currentTimeMillis(); + TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE); + searcher.search(simQuery, results); + long time = System.currentTimeMillis() - start; - System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word)); + System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word)); - int rank = 1; - for (ScoreDoc sd : results.topDocs().scoreDocs) { - Document document = reader.document(sd.doc); - String word = document.get(IndexVectors.FIELD_WORD); - System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score)); - rank++; + int rank = 1; + for (ScoreDoc sd : results.topDocs().scoreDocs) { + Document document = reader.document(sd.doc); + String word = document.get(IndexVectors.FIELD_ID); + System.out.println(String.format("%d. 
%s (%.3f)", rank, word, sd.score)); + rank++; + } + System.out.println(String.format("Search time: %dms", time)); } - System.out.println(String.format("Search time: %dms", time)); - reader.close(); d.close(); } diff --git a/src/main/java/io/anserini/ann/IndexVectors.java b/src/main/java/io/anserini/ann/IndexVectors.java index 1e7aebc5d2..13834805a7 100644 --- a/src/main/java/io/anserini/ann/IndexVectors.java +++ b/src/main/java/io/anserini/ann/IndexVectors.java @@ -49,21 +49,24 @@ import java.util.concurrent.atomic.AtomicInteger; public class IndexVectors { - static final String FIELD_WORD = "word"; - static final String FIELD_VECTOR = "vector"; + public static final String FIELD_ID = "id"; + public static final String FIELD_VECTOR = "vector"; - private static final String FW = "fw"; - private static final String LEXLSH = "lexlsh"; + public static final String FW = "fw"; + public static final String LEXLSH = "lexlsh"; public static final class Args { - @Option(name = "-input", metaVar = "[file]", required = true, usage = "word vectors model") + @Option(name = "-input", metaVar = "[file]", required = true, usage = "vectors model") public File input; @Option(name = "-path", metaVar = "[path]", required = true, usage = "index path") public Path path; @Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}") - public String encoding; + public String encoding = FW; + + @Option(name="-stored", metaVar = "[boolean]", usage = "store vectors") + public boolean stored; @Option(name = "-lexlsh.n", metaVar = "[int]", usage = "ngrams") public int ngrams = 2; @@ -81,7 +84,7 @@ public static final class Args { public int bucketCount = 300; @Option(name = "-fw.q", metaVar = "[int]", usage = "quantization factor") - public int q = 60; + public int q = FakeWordsEncoderAnalyzer.DEFAULT_Q; } public static void main(String[] args) throws Exception { @@ -113,7 +116,7 @@ public static void main(String[] args) throws Exception { final long start = System.nanoTime(); System.out.println(String.format("Loading model %s", indexArgs.input)); - Map wordVectors = readGloVe(indexArgs.input); + Map vectors = readGloVe(indexArgs.input); Path indexDir = indexArgs.path; if (!Files.exists(indexDir)) { @@ -131,10 +134,10 @@ public static void main(String[] args) throws Exception { IndexWriter indexWriter = new IndexWriter(d, conf); final AtomicInteger cnt = new AtomicInteger(); - for (Map.Entry entry : wordVectors.entrySet()) { + for (Map.Entry entry : vectors.entrySet()) { Document doc = new Document(); - doc.add(new StringField(FIELD_WORD, entry.getKey(), Field.Store.YES)); + doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES)); float[] vector = entry.getValue(); StringBuilder sb = new StringBuilder(); for (double fv : vector) { @@ -143,12 +146,12 @@ public static void main(String[] args) throws Exception { } sb.append(fv); } - doc.add(new TextField(FIELD_VECTOR, sb.toString(), Field.Store.NO)); + doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? 
Field.Store.YES : Field.Store.NO)); try { indexWriter.addDocument(doc); int cur = cnt.incrementAndGet(); if (cur % 100000 == 0) { - System.out.println(String.format("%s words added", cnt)); + System.out.println(String.format("%s docs added", cnt)); } } catch (IOException e) { System.err.println("Error while indexing: " + e.getLocalizedMessage()); @@ -156,7 +159,7 @@ public static void main(String[] args) throws Exception { } indexWriter.commit(); - System.out.println(String.format("%s words indexed", cnt.get())); + System.out.println(String.format("%s docs indexed", cnt.get())); long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L); System.out.println(String.format("Index size: %dMB", space)); indexWriter.close(); diff --git a/src/main/java/io/anserini/ann/fw/FakeWordsEncoderAnalyzer.java b/src/main/java/io/anserini/ann/fw/FakeWordsEncoderAnalyzer.java index 214fe1be72..09701dc0be 100644 --- a/src/main/java/io/anserini/ann/fw/FakeWordsEncoderAnalyzer.java +++ b/src/main/java/io/anserini/ann/fw/FakeWordsEncoderAnalyzer.java @@ -31,7 +31,7 @@ public class FakeWordsEncoderAnalyzer extends Analyzer { static final String REMOVE_IT = "_"; - private static final int DEFAULT_Q = 60; + public static final int DEFAULT_Q = 80; private final int q; diff --git a/src/main/java/io/anserini/ann/lexlsh/LexicalLshAnalyzer.java b/src/main/java/io/anserini/ann/lexlsh/LexicalLshAnalyzer.java index 082e39837f..ff4a5d740f 100644 --- a/src/main/java/io/anserini/ann/lexlsh/LexicalLshAnalyzer.java +++ b/src/main/java/io/anserini/ann/lexlsh/LexicalLshAnalyzer.java @@ -32,7 +32,7 @@ */ public class LexicalLshAnalyzer extends Analyzer { - private static final int DEFAULT_SHINGLE_SIZE = 5; + private static final int DEFAULT_SHINGLE_SIZE = 2; private static final int DEFAULT_DECIMALS = 1; private final int min; diff --git a/src/main/java/io/anserini/search/SimpleNearestNeighborSearcher.java b/src/main/java/io/anserini/search/SimpleNearestNeighborSearcher.java new file mode 100644 index 0000000000..e09aa0277c --- /dev/null +++ b/src/main/java/io/anserini/search/SimpleNearestNeighborSearcher.java @@ -0,0 +1,101 @@ +/* + * Anserini: A Lucene toolkit for replicable information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.anserini.search; + +import java.io.IOException; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import io.anserini.analysis.AnalyzerUtils; +import io.anserini.ann.IndexVectors; +import io.anserini.ann.fw.FakeWordsEncoderAnalyzer; +import io.anserini.ann.lexlsh.LexicalLshAnalyzer; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.queries.CommonTermsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.similarities.ClassicSimilarity; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; + +import static org.apache.lucene.search.BooleanClause.Occur.SHOULD; + +public class SimpleNearestNeighborSearcher { + + private final Analyzer analyzer; + private final IndexSearcher searcher; + + public SimpleNearestNeighborSearcher(String path) throws IOException { + this(path, IndexVectors.FW); + } + + public SimpleNearestNeighborSearcher(String path, String encoding) throws IOException { + Directory d = FSDirectory.open(Paths.get(path)); + DirectoryReader reader = DirectoryReader.open(d); + searcher = new IndexSearcher(reader); + if (encoding.equalsIgnoreCase(IndexVectors.LEXLSH)) { + analyzer = new LexicalLshAnalyzer(); + } else if (encoding.equalsIgnoreCase(IndexVectors.FW)) { + analyzer = new FakeWordsEncoderAnalyzer(); + searcher.setSimilarity(new ClassicSimilarity()); + } else { + throw new RuntimeException("unexpected encoding " + encoding); + } + } + + public Result[][] search(String word, int k) throws IOException { + List results = new ArrayList<>(); + TopDocs wordDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, word)), k); + + for (ScoreDoc scoreDoc : wordDocs.scoreDocs) { + Document doc = searcher.doc(scoreDoc.doc); + String vector = doc.get(IndexVectors.FIELD_VECTOR); + CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, 0); + List tokens = AnalyzerUtils.analyze(analyzer, vector); + for (String token : tokens) { + simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token)); + } + TopDocs nearest = searcher.search(simQuery, k); + Result[] neighbors = new Result[nearest.scoreDocs.length]; + int i = 0; + for (ScoreDoc nn : nearest.scoreDocs) { + Document ndoc = searcher.doc(nn.doc); + neighbors[i] = new Result(ndoc.get(IndexVectors.FIELD_ID), nn.score); + i++; + } + results.add(neighbors); + } + return results.toArray(new Result[0][0]); + } + + public static class Result { + + public final String id; + public final float score; + + private Result(String id, float score) { + this.id = id; + this.score = score; + } + } +} diff --git a/src/test/java/io/anserini/ann/ApproximateNearestNeighborEvalTest.java b/src/test/java/io/anserini/ann/ApproximateNearestNeighborEvalTest.java index a1b1658e8c..40ef4a2a5a 100644 --- a/src/test/java/io/anserini/ann/ApproximateNearestNeighborEvalTest.java +++ b/src/test/java/io/anserini/ann/ApproximateNearestNeighborEvalTest.java @@ -27,7 +27,7 @@ public class ApproximateNearestNeighborEvalTest { public void evalFWTest() throws Exception { String path = "target/idx-sample-fw"; String encoding = "fw"; - IndexVectorsTest.createIndex(path, encoding); + IndexVectorsTest.createIndex(path, encoding, false); String[] args = 
new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path", path, "-topics", "src/test/resources/sample_topics/Trec"}; ApproximateNearestNeighborEval.main(args); @@ -37,7 +37,7 @@ public void evalFWTest() throws Exception { public void evalLLTest() throws Exception { String path = "target/idx-sample-ll"; String encoding = "lexlsh"; - IndexVectorsTest.createIndex(path, encoding); + IndexVectorsTest.createIndex(path, encoding, false); String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path", path, "-topics", "src/test/resources/sample_topics/Trec"}; ApproximateNearestNeighborEval.main(args); diff --git a/src/test/java/io/anserini/ann/ApproximateNearestNeighborSearchTest.java b/src/test/java/io/anserini/ann/ApproximateNearestNeighborSearchTest.java index 9f2fd9dd55..c79f9516f5 100644 --- a/src/test/java/io/anserini/ann/ApproximateNearestNeighborSearchTest.java +++ b/src/test/java/io/anserini/ann/ApproximateNearestNeighborSearchTest.java @@ -25,22 +25,40 @@ public class ApproximateNearestNeighborSearchTest { @Test public void searchFWTest() throws Exception { - String path = "target/idx-sample-fw"; + String path = "target/idx-sample-fw" + System.currentTimeMillis(); String encoding = "fw"; - IndexVectorsTest.createIndex(path, encoding); + IndexVectorsTest.createIndex(path, encoding, false); String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path", path, "-word", "foo"}; ApproximateNearestNeighborSearch.main(args); } + @Test + public void searchFWStoredTest() throws Exception { + String path = "target/idx-sample-fw-stored" + System.currentTimeMillis(); + String encoding = "fw"; + IndexVectorsTest.createIndex(path, encoding, true); + String[] args = new String[]{"-encoding", encoding, "-stored", "-path", path, "-word", "foo"}; + ApproximateNearestNeighborSearch.main(args); + } + @Test public void searchLLTest() throws Exception { - String path = "target/idx-sample-ll"; + String path = "target/idx-sample-ll" + System.currentTimeMillis(); String encoding = "lexlsh"; - IndexVectorsTest.createIndex(path, encoding); + IndexVectorsTest.createIndex(path, encoding, false); String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path", path, "-word", "foo"}; ApproximateNearestNeighborSearch.main(args); } + @Test + public void searchLLStoredTest() throws Exception { + String path = "target/idx-sample-ll" + System.currentTimeMillis(); + String encoding = "lexlsh"; + IndexVectorsTest.createIndex(path, encoding, true); + String[] args = new String[]{"-encoding", encoding, "-stored", "-path", path, "-word", "foo"}; + ApproximateNearestNeighborSearch.main(args); + } + } \ No newline at end of file diff --git a/src/test/java/io/anserini/ann/IndexVectorsTest.java b/src/test/java/io/anserini/ann/IndexVectorsTest.java index 6671e4cb84..7c935ae30a 100644 --- a/src/test/java/io/anserini/ann/IndexVectorsTest.java +++ b/src/test/java/io/anserini/ann/IndexVectorsTest.java @@ -15,6 +15,9 @@ */ package io.anserini.ann; +import java.util.LinkedList; +import java.util.List; + import org.junit.Test; /** @@ -24,18 +27,36 @@ public class IndexVectorsTest { @Test public void indexFWTest() throws Exception { - createIndex("target/idx-sample-fw", "fw"); + createIndex("target/idx-sample-fw" + System.currentTimeMillis(), "fw", false); + } + + @Test + public void indexFWStoredTest() throws Exception { + 
createIndex("target/idx-sample-fw" + System.currentTimeMillis(), "fw", false); } @Test public void indexLLTest() throws Exception { - createIndex("target/idx-sample-ll", "lexlsh"); + createIndex("target/idx-sample-ll" + System.currentTimeMillis(), "lexlsh", false); + } + + @Test + public void indexLLStoredTest() throws Exception { + createIndex("target/idx-sample-ll" + System.currentTimeMillis(), "lexlsh", false); } - static void createIndex(String path, String encoding) throws Exception { - String[] args = new String[]{"-encoding", encoding, "-input", "src/test/resources/mini-word-vectors.txt", "-path", - path}; - IndexVectors.main(args); + public static void createIndex(String path, String encoding, boolean stored) throws Exception { + List args = new LinkedList<>(); + args.add("-encoding"); + args.add(encoding); + args.add("-input"); + args.add("src/test/resources/mini-word-vectors.txt"); + args.add("-path"); + args.add(path); + if (stored) { + args.add("-stored"); + } + IndexVectors.main(args.toArray(new String[0])); } diff --git a/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java b/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java new file mode 100644 index 0000000000..f406dcb9e2 --- /dev/null +++ b/src/test/java/io/anserini/search/SimpleNearestNeighborSearcherTest.java @@ -0,0 +1,47 @@ +/* + * Anserini: A Lucene toolkit for replicable information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.anserini.search; + +import io.anserini.ann.IndexVectorsTest; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class SimpleNearestNeighborSearcherTest { + + @Test + public void testSearchingFW() throws Exception { + String idxPath = "target/ast" + System.currentTimeMillis(); + IndexVectorsTest.createIndex(idxPath, "fw", true); + SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath); + SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.search("text", 2); + assertNotNull(results); + assertEquals(1, results.length); + assertEquals(2, results[0].length); + } + + @Test + public void testSearchingLL() throws Exception { + String idxPath = "target/ast" + System.currentTimeMillis(); + IndexVectorsTest.createIndex(idxPath, "lexlsh", true); + SimpleNearestNeighborSearcher simpleNearestNeighborSearcher = new SimpleNearestNeighborSearcher(idxPath, "lexlsh"); + SimpleNearestNeighborSearcher.Result[][] results = simpleNearestNeighborSearcher.search("text", 2); + assertNotNull(results); + assertEquals(1, results.length); + assertEquals(1, results[0].length); + } +} \ No newline at end of file