forked from stanford-futuredata/ColBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
searcher.py
130 lines (99 loc) · 4.94 KB
/
searcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import torch
from tqdm import tqdm
from typing import Union
from colbert.data import Collection, Queries, Ranking
from colbert.modeling.checkpoint import Checkpoint
from colbert.search.index_storage import IndexScorer
from colbert.infra.provenance import Provenance
from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.infra.launcher import print_memory_stats
import time
# Query inputs accepted by the Searcher API: a single query string, a list of
# query strings, a ``dict[int, str]`` (presumably qid -> query text; confirm
# against ``Queries.cast``), or a colbert ``Queries`` object.
TextQueries = Union[
    str,
    'list[str]',
    'dict[int, str]',
    Queries,
]
class Searcher:
    """Search a prebuilt ColBERT index with a ColBERT checkpoint.

    On construction the searcher loads the index configuration stored under
    ``index_root/index``, loads the checkpoint's configuration, and combines
    both with the caller's ``config`` via ``ColBERTConfig.from_existing``.
    It then loads the checkpoint (moving it to GPU when GPUs are visible)
    and builds an ``IndexScorer`` over the index.
    """

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose: int = 3):
        """Load the index, its configuration and the query encoder.

        Args:
            index: index name, joined onto ``index_root``.
            checkpoint: checkpoint to encode queries with; defaults to the
                checkpoint recorded in the index configuration.
            collection: accepted for API compatibility; not used by this
                constructor.
            config: optional ColBERTConfig, layered over the active
                ``Run().config``.
            index_root: directory holding indexes; defaults to the resolved
                config's ``index_root_``.
            verbose: verbosity level. Memory statistics are printed only
                when it is greater than 1; it is also forwarded to
                ``Checkpoint``.

        Raises:
            ValueError: if ``load_index_with_mmap`` is enabled while GPUs
                are visible.
        """
        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        initial_config = ColBERTConfig.from_existing(config, Run().config)

        default_index_root = initial_config.index_root_
        index_root = index_root if index_root else default_index_root
        self.index = os.path.join(index_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        # NOTE(review): precedence among these sources is decided by
        # ColBERTConfig.from_existing (presumably later sources win) - confirm.
        self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config)

        self.configure(checkpoint=self.checkpoint)

        # From here on self.checkpoint is the loaded model, not its path/name.
        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)
        use_gpu = self.config.total_visible_gpus > 0
        if use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        load_index_with_mmap = self.config.load_index_with_mmap
        if load_index_with_mmap and use_gpu:
            raise ValueError("Memory-mapped index can only be used with CPU!")
        self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap)

        # Gated like the first call so low verbosity levels stay quiet.
        if self.verbose > 1:
            print_memory_stats()

    def configure(self, **kw_args):
        """Forward keyword settings to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def encode(self, text: TextQueries, full_length_search=False):
        """Encode one query or a list of queries with the checkpoint.

        Args:
            text: a list of query strings, or a single query, which is
                wrapped into a one-element list. Dicts and ``Queries``
                objects are not unpacked here; ``search_all`` does that.
            full_length_search: forwarded to ``queryFromText``.

        Returns:
            The output of ``self.checkpoint.queryFromText`` for the batch,
            requested with ``to_cpu=True``.
        """
        queries = text if isinstance(text, list) else [text]
        # Explicit batching only for inputs above 128 queries; smaller
        # inputs pass bsize=None (handling of None is up to queryFromText).
        bsize = 128 if len(queries) > 128 else None

        # Keep the tokenizer in sync with the (possibly reconfigured) limit.
        self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
        return self.checkpoint.queryFromText(
            queries, bsize=bsize, to_cpu=True, full_length_search=full_length_search
        )

    def search(self, text: str, k=10, filter_fn=None, full_length_search=False, pids=None):
        """Encode ``text`` and return ``dense_search``'s ``(pids, ranks, scores)``."""
        Q = self.encode(text, full_length_search=full_length_search)
        return self.dense_search(Q, k, filter_fn=filter_fn, pids=pids)

    def search_all(self, queries: TextQueries, k=10, filter_fn=None, full_length_search=False, qid_to_pids=None):
        """Search every query and return the results as a ``Ranking``.

        Args:
            queries: anything ``Queries.cast`` accepts.
            k: results per query.
            filter_fn: forwarded to ``dense_search``.
            full_length_search: forwarded to ``encode``.
            qid_to_pids: optional mapping qid -> candidate pids restricting
                each query's search; when given it must contain every qid.
        """
        queries = Queries.cast(queries)
        queries_ = list(queries.values())

        Q = self.encode(queries_, full_length_search=full_length_search)

        return self._search_all_Q(queries, Q, k, filter_fn=filter_fn, qid_to_pids=qid_to_pids)

    def _search_all_Q(self, queries, Q, k, filter_fn=None, qid_to_pids=None):
        """Run ``dense_search`` for each row of ``Q`` and wrap the results.

        Row ``i`` of ``Q`` is assumed to belong to the ``i``-th key of
        ``queries``, which holds when ``Q`` came from ``encode`` over
        ``queries.values()``.

        Returns:
            A ``Ranking`` mapping each qid to a list of
            ``(pid, rank, score)`` tuples, with provenance recording the
            queries, the exported config and ``k``.
        """
        qids = list(queries.keys())

        all_scored_pids = []
        for query_idx, qid in tqdm(enumerate(qids), total=len(qids)):
            candidate_pids = None if qid_to_pids is None else qid_to_pids[qid]
            pids, ranks, scores = self.dense_search(
                Q[query_idx:query_idx + 1],
                k,
                filter_fn=filter_fn,
                pids=candidate_pids,
            )
            all_scored_pids.append(list(zip(pids, ranks, scores)))

        data = dict(zip(qids, all_scored_pids))

        provenance = Provenance()
        provenance.source = 'Searcher::search_all'
        provenance.queries = queries.provenance()
        provenance.config = self.config.export()
        provenance.k = k

        return Ranking(data=data, provenance=provenance)

    @staticmethod
    def _default_search_params(k):
        """Return the k-tiered fallback values for the PLAID search knobs."""
        if k <= 10:
            return {'ncells': 1, 'centroid_score_threshold': 0.5, 'ndocs': 256}
        if k <= 100:
            return {'ncells': 2, 'centroid_score_threshold': 0.45, 'ndocs': 1024}
        return {'ncells': 4, 'centroid_score_threshold': 0.4, 'ndocs': max(k * 4, 4096)}

    def _apply_search_defaults(self, k):
        """Fill search parameters the user left unset with defaults for ``k``.

        A parameter is (re)configured when it is None, or when it still holds
        the value this method wrote on an earlier call. The second case keeps
        the defaults tied to the *current* ``k``. Previously, the first call
        pinned its tier's values, e.g. ``ndocs=256`` for k <= 10, which is
        presumably too small for later large-k searches. Values the user set
        explicitly are left alone, unless they happen to equal the last
        auto-filled value.
        """
        # Created lazily so instances built without __init__ also work.
        auto_filled = self.__dict__.setdefault('_auto_search_params', {})
        for name, value in self._default_search_params(k).items():
            current = getattr(self.config, name)
            owned_by_us = current is None or auto_filled.get(name) == current
            if owned_by_us and current != value:
                self.configure(**{name: value})
                auto_filled[name] = value

    def dense_search(self, Q: torch.Tensor, k=10, filter_fn=None, pids=None):
        """Rank the index against encoded query ``Q`` and keep the top ``k``.

        Before ranking, unset ``ncells`` / ``centroid_score_threshold`` /
        ``ndocs`` are filled with k-dependent defaults.

        Args:
            Q: query encoding from ``encode``. ``_search_all_Q`` passes a
                one-row slice per query.
            k: maximum number of results to return.
            filter_fn: forwarded to ``IndexScorer.rank``.
            pids: optional candidate pids, forwarded to ``IndexScorer.rank``.

        Returns:
            ``(pids, ranks, scores)``: the first ``k`` pids and scores in the
            order returned by the ranker, plus 1-based ranks of the same
            length. The ranker may return fewer than ``k`` hits.
        """
        self._apply_search_defaults(k)

        pids, scores = self.ranker.rank(self.config, Q, filter_fn=filter_fn, pids=pids)
        pids, scores = pids[:k], scores[:k]

        # Size ranks to the actual hits so all three sequences line up.
        return pids, list(range(1, len(pids) + 1)), scores