Skip to content

Commit

Permalink
Add automate downloading of prebuilt indexes (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
qguo96 authored Sep 25, 2020
1 parent 6b371ed commit 9ce3b00
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 5 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ for i in range(0, 10):
print(f'{i+1:2} {hits2[i].docid:15} {hits2[i].score:.5f}')
```

## How Do I Build Searchers?

There are two ways to build searchers.
+ Build a searcher from the path to a directory containing the index:
```python
searcher = SimpleSearcher('indexes/index-robust04-20191213/')
```
+ Build a searcher from the identifier of a prebuilt index, which is downloaded automatically:
```python
searcher = SimpleSearcher.from_prebuilt_index('trec45')
```
`from_prebuilt_index` currently supports the following identifiers:
+ trec45 (TREC Disks 4 & 5)
+ robust04 (TREC Disks 4 & 5)
+ ms-marco-passage (MS MARCO Passage)
+ ms-marco-doc (MS MARCO Doc)

## How Do I Fetch a Document?

The other commonly used feature is to fetch a document given its `docid`.
Expand Down
6 changes: 6 additions & 0 deletions pyserini/search/_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyserini.pyclass import autoclass, JString, JArrayList
from pyserini.trectools import TrecRun
from pyserini.fusion import FusionMethod, reciprocal_rank_fusion
from pyserini.util import download_prebuilt_index

logger = logging.getLogger(__name__)

Expand All @@ -48,6 +49,11 @@ def __init__(self, index_dir: str):
self.object = JSimpleSearcher(JString(index_dir))
self.num_docs = self.object.getTotalNumDocuments()

@classmethod
def from_prebuilt_index(cls, prebuilt_index_name: str):
    """Build a searcher from a prebuilt index identified by name.

    Parameters
    ----------
    prebuilt_index_name : str
        Identifier of the prebuilt index, passed unchanged to
        ``download_prebuilt_index``.

    Returns
    -------
    An instance of ``cls`` constructed on the directory returned by
    ``download_prebuilt_index``.
    """
    # The download helper hands back the local index directory, which is
    # exactly what the regular constructor expects.
    return cls(download_prebuilt_index(prebuilt_index_name))

def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, strip_segment_id=False, remove_dups=False) -> List[JSimpleSearcherResult]:
"""Search the collection.
Expand Down
58 changes: 53 additions & 5 deletions pyserini/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@
from tqdm import tqdm
from urllib.request import urlretrieve

# Download metadata for each prebuilt index: one or more named mirrors
# ('urls') plus the MD5 checksum of the tarball they serve.
INDEX_INFO = {
    'index-marco-passage': {
        'urls': {
            'uwaterloo': 'https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-msmarco-passage-20191117-0ed488.tar.gz',
        },
        'md5': '3c2ef64ee6d0ee8e317adcb341b92e28',
    },
    'index-marco-doc': {
        'urls': {
            'dropbox': 'https://www.dropbox.com/s/awukuo8c0tkl9sc/index-msmarco-doc-20200527-a1ecfa.tar.gz?dl=1',
        },
        'md5': '72b1a0f9a9094a86d15c6f4babf8967a',
    },
    'index-robust04': {
        'urls': {
            'uwaterloo': 'https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz',
        },
        'md5': '15f3d001489c97849a010b0a4734d018',
    },
}

# User-facing identifiers -> entries of INDEX_INFO. Values are the very same
# dict objects as in INDEX_INFO; 'trec45' and 'robust04' are two names for
# one index.
INDEX_MAPPING = {
    alias: INDEX_INFO[info_key]
    for alias, info_key in (
        ('ms-marco-passage', 'index-marco-passage'),
        ('ms-marco-doc', 'index-marco-doc'),
        ('trec45', 'index-robust04'),
        ('robust04', 'index-robust04'),
    )
}


# https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
class TqdmUpTo(tqdm):
Expand Down Expand Up @@ -71,14 +90,27 @@ def download_url(url, save_dir, md5=None, force=False, verbose=True):
if md5:
assert compute_md5(destination_path) == md5, f'{destination_path} does not match checksum!'

def get_cache_home():
    """Return the pyserini cache directory under the user's home.

    The path is ``~/.cache/pyserini`` (using the platform's path separator)
    with the leading ``~`` expanded by ``os.path.expanduser``. The directory
    is not created here.
    """
    unexpanded = os.path.join('~', '.cache', 'pyserini')
    return os.path.expanduser(unexpanded)

def download_and_unpack_index(url, index_directory='indexes', force=False, verbose=True):
def download_and_unpack_index(url, index_directory='indexes', force=False, verbose=True, prebuilt=False, md5=None):
index_name = url.split('/')[-1]
index_name = re.sub('''.tar.gz.*$''', '', index_name)

index_path = f'{index_directory}/{index_name}'
local_tarball = f'{index_directory}/{index_name}.tar.gz'
if prebuilt:
index_directory = os.path.join(get_cache_home(), 'indexes')
index_path = os.path.join(index_directory, f'{index_name}{md5}')
local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')
if not os.path.exists(index_directory):
os.makedirs(index_directory)
else:
index_path = os.path.join(index_directory, f'{index_name}')

local_tarball = os.path.join(index_directory, f'{index_name}.tar.gz')

if prebuilt:
if os.path.exists(local_tarball):
os.remove(local_tarball)
if verbose:
print(f'Downloading index at {url}...')

Expand All @@ -90,16 +122,32 @@ def download_and_unpack_index(url, index_directory='indexes', force=False, verbo
if not force:
if verbose:
print(f'Skipping download.')
return
return index_path
if verbose:
print(f'force=True, removing {index_path}; fetching fresh copy...')
shutil.rmtree(index_path)

download_url(url, index_directory, verbose=False)
download_url(url, index_directory, verbose=False, md5=md5)

if verbose:
print(f'Extracting {local_tarball} into {index_path}...')
tarball = tarfile.open(local_tarball)
tarball.extractall(index_directory)
tarball.close()
os.remove(local_tarball)
if prebuilt:
os.rename(os.path.join(index_directory, f'{index_name}'), index_path)
return index_path


def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None):
    """Download a prebuilt index by identifier and return its local path.

    Args:
        index_name: Identifier of the prebuilt index; must be a key of
            ``INDEX_MAPPING``.
        force: Forwarded to ``download_and_unpack_index``, which uses it to
            decide between skipping the download and removing the existing
            copy to fetch a fresh one.
        verbose: Forwarded to ``download_and_unpack_index`` to control its
            progress messages.
        mirror: Name of the download mirror, i.e. a key of the index's
            ``'urls'`` table. When falsy, the first mirror listed for the
            index is used.

    Returns:
        The value returned by ``download_and_unpack_index`` (the path of the
        unpacked index).

    Raises:
        ValueError: If ``index_name`` is not a known identifier, or if
            ``mirror`` is given but not available for that index.
    """
    if index_name not in INDEX_MAPPING:
        raise ValueError("unrecognized index name {}".format(index_name))

    index_info = INDEX_MAPPING[index_name]
    urls = index_info["urls"]
    if not mirror:
        # No preference given: fall back to the first mirror in the table.
        mirror = next(iter(urls))
    elif mirror not in urls:
        raise ValueError("unrecognized mirror name {}".format(mirror))

    # force/verbose were previously accepted but silently dropped; pass them
    # through so callers can actually re-download or silence output.
    return download_and_unpack_index(
        urls[mirror],
        force=force,
        verbose=verbose,
        prebuilt=True,
        md5=index_info["md5"],
    )

0 comments on commit 9ce3b00

Please sign in to comment.