forked from stanford-futuredata/ColBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
queries.py
163 lines (115 loc) · 4.29 KB
/
queries.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from colbert.infra.run import Run
import os
import ujson
from colbert.evaluation.loaders import load_queries
# TODO: Look up the path in some global (per-thread or thread-safe) list.
# TODO: `path` could be a list of paths, but then how could we tell it apart from a list of queries?
class Queries:
    """Mapping-like collection of queries keyed by query ID (qid).

    The query text is kept in ``self.data`` (``qid -> question``). When the
    input records are dicts (question/answer style records), the full records
    are additionally kept in ``self._qas`` (``qid -> record``). ``self._qas``
    is only guaranteed to exist when such records were loaded.
    """

    def __init__(self, path=None, data=None):
        """Build the collection from an in-memory dict or from a file.

        Args:
            path: Path of the query file. Used only when ``data`` is None, but
                always remembered as the provenance of this collection.
            data: Optional dict mapping qid to either the query text or a
                record dict with a ``'question'`` key. Takes precedence over
                ``path``.

        Raises:
            AssertionError: If ``data`` is given but is not a dict.
            ValueError: If neither ``path`` nor ``data`` is given.
        """
        self.path = path

        if data is not None:
            # Validate every non-None value (including empty containers) so a
            # wrong type fails here rather than deep inside _load_data().
            assert isinstance(data, dict), type(data)
            self._load_data(data)
        elif path is not None:
            self._load_file(path)
        else:
            raise ValueError("Queries requires either `path` or `data`")

    def __len__(self):
        """Return the number of queries."""
        return len(self.data)

    def __iter__(self):
        """Iterate over ``(qid, question)`` pairs (not over bare keys)."""
        return iter(self.data.items())

    def provenance(self):
        """Return the path this collection was created with (may be None)."""
        return self.path

    def toDict(self):
        """Return a small dict describing where the queries came from."""
        return {'provenance': self.provenance()}

    def _load_data(self, data):
        """Populate ``self.data`` (and possibly ``self._qas``) from a dict.

        Dict-valued entries are treated as QA records: their ``'question'``
        becomes the query text and the whole record is stored in
        ``self._qas``. Any other value is stored as the query text directly.

        Returns:
            None if ``data`` is None (nothing is loaded), True otherwise.
        """
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        # Without any QA records the attribute is removed entirely, so
        # qas() / save_qas() raise AttributeError for plain-text collections.
        if not self._qas:
            del self._qas

        return True

    def _load_file(self, path):
        """Populate the collection from the file at ``path``.

        Paths ending in ``.json`` are read as one JSON object per line, each
        carrying at least ``'qid'`` and ``'question'``. Every other path is
        delegated to ``load_queries`` (presumably a TSV reader -- confirm
        against ``colbert.evaluation.loaders``).

        Returns:
            True for non-``.json`` paths; ``self.data`` for ``.json`` paths
            (kept as in the original implementation).

        Raises:
            AssertionError: If a ``.json`` file repeats a qid.
        """
        if not path.endswith('.json'):
            self.data = load_queries(path)
            return True

        # Load QAs
        self.data = {}
        self._qas = {}

        with open(path, encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines instead of failing inside the parser.
                if not line.strip():
                    continue

                qa = ujson.loads(line)
                qid = qa['qid']

                assert qid not in self.data, f"duplicate qid {qid!r} in {path}"

                self.data[qid] = qa['question']
                self._qas[qid] = qa

        return self.data

    def qas(self):
        """Return a shallow copy of the ``qid -> QA record`` mapping.

        Raises:
            AttributeError: If no QA records were loaded (``_qas`` is absent).
        """
        return dict(self._qas)

    def __getitem__(self, key):
        """Return the query text stored under ``key``."""
        return self.data[key]

    def keys(self):
        """Return a view of the qids."""
        return self.data.keys()

    def values(self):
        """Return a view of the query texts."""
        return self.data.values()

    def items(self):
        """Return a view of ``(qid, question)`` pairs."""
        return self.data.items()

    def save(self, new_path):
        """Write the queries as ``qid<TAB>question`` lines to ``new_path``.

        The file is opened through ``Run().open``; the returned name is that
        of the file object this helper produced.

        Raises:
            AssertionError: If ``new_path`` does not end in ``.tsv`` or
                already exists.
        """
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, content in self.data.items():
                f.write(f'{qid}\t{content}\n')

            return f.name

    def save_qas(self, new_path):
        """Write the QA records to ``new_path``, one JSON object per line.

        Each written record carries its qid under ``'qid'``. The stored
        records themselves are left unmodified.

        Raises:
            AssertionError: If ``new_path`` does not end in ``.json`` or
                already exists.
            AttributeError: If no QA records were loaded (``_qas`` is absent).
        """
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        with open(new_path, 'w', encoding='utf-8') as f:
            for qid, qa in self._qas.items():
                # Merge into a copy: same serialized output as setting
                # qa['qid'] in place, without mutating records that may be
                # shared with the caller.
                f.write(ujson.dumps({**qa, 'qid': qid}) + '\n')

    def _load_tsv(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    def _load_jsonl(self, path):
        """Placeholder; not implemented."""
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into an instance of this class.

        * an existing instance (including subclasses) is returned as-is;
        * a string is treated as a path and loaded;
        * a dict or list is passed as ``data``. Note that the constructor
          currently only accepts dicts, so a list is rejected there.

        Raises:
            AssertionError: If ``obj`` is of any other type.
        """
        if isinstance(obj, cls):
            return obj

        if isinstance(obj, str):
            return cls(path=obj)

        if isinstance(obj, (dict, list)):
            return cls(data=obj)

        # Raised explicitly (rather than `assert False`) so the failure is not
        # silently skipped when Python runs with -O.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")
# class QuerySet:
# def __init__(self, *paths, renumber=False):
# self.paths = paths
# self.original_queries = [load_queries(path) for path in paths]
# if renumber:
# self.queries = flatten([q.values() for q in self.original_queries])
# self.queries = {idx: text for idx, text in enumerate(self.queries)}
# else:
# self.queries = {}
# for queries in self.original_queries:
# assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
# "renumber=False requires non-overlapping query IDs"
# self.queries.update(queries)
# assert len(self.queries) == sum(map(len, self.original_queries))
# def todict(self):
# return dict(self.queries)
# def tolist(self):
# return list(self.queries.values())
# def query_sets(self):
# return self.original_queries
# def split_rankings(self, rankings):
# assert type(rankings) is list
# assert len(rankings) == len(self.queries)
# sub_rankings = []
# offset = 0
# for source in self.original_queries:
# sub_rankings.append(rankings[offset:offset+len(source)])
# offset += len(source)
# return sub_rankings