-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrec_test.py
More file actions
63 lines (56 loc) · 1.97 KB
/
trec_test.py
File metadata and controls
63 lines (56 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
def _load_qrels(path: str) -> dict:
    """Read a whitespace-separated 4-column qrels file into ``{qid: {did: 1}}``."""
    judged = {}
    with open(path, 'r') as f_qrel:
        for line in f_qrel:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines such as a trailing newline
            qid, _, did, _label = fields
            # Every listed (qid, did) pair is stored as relevant; the label
            # column is parsed but not consulted.
            # NOTE(review): confirm the files never carry 0-labelled rows.
            judged.setdefault(qid, {})[did] = 1
    return judged


def _load_run(path: str) -> dict:
    """Read a whitespace-separated 6-column run file into ``{qid: [did, ...]}``."""
    ranked = {}
    with open(path, 'r') as f_run:
        for line in f_run:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines such as a trailing newline
            qid, _, did, _, _, _ = fields
            # Documents keep their file order; the rank/score columns are
            # not used to re-sort them.
            ranked.setdefault(qid, []).append(did)
    return ranked


def _reciprocal_rank(docs: list, relevant: dict, k: int) -> float:
    """Return 1/rank of the first relevant document within the top *k*, else 0.0."""
    for rank, did in enumerate(docs[:k], start=1):
        if relevant.get(did, 0) > 0:
            return 1 / rank
    return 0.0


def get_mrr(qrels: str, trec: str, qrelss, metric: str = 'mrr_cut_10') -> float:
    """Compute MRR@k of a run against ``qrels`` and cross-check it with ``qrelss``.

    Args:
        qrels: Path to the reference qrels file (4 whitespace-separated
            columns per line: qid, unused, did, label).
        trec: Path to the run file (6 whitespace-separated columns per line;
            the first is the qid and the third the did). Documents are ranked
            in the order they appear in the file.
        qrelss: Path to a second qrels file in the same 4-column format. It is
            only used for the diagnostic comparison and never affects the
            returned value.
        metric: Metric name whose last ``_``-separated token is the cut-off
            ``k`` (e.g. ``'mrr_cut_10'`` -> 10).

    Returns:
        The mean, over every query in the run, of the reciprocal rank of the
        first relevant document in the top ``k`` (0 for a query with no hit).
        Returns 0.0 when the run file contains no queries.

    Side effects:
        For the first query whose reciprocal rank differs between the two
        qrels files, prints the qid, both judgment dicts (``{}`` if the query
        is absent from a file) and both running sums at that point.

    Raises:
        ValueError: If ``k`` cannot be parsed from ``metric`` or a non-blank
            input line does not have the expected number of columns.
    """
    k = int(metric.split('_')[-1])
    qrel = _load_qrels(qrels)
    qrell = _load_qrels(qrelss)
    run = _load_run(trec)
    if not run:
        return 0.0  # no queries: avoid dividing by zero

    mrr = 0.0
    mrrr = 0.0
    reported = False
    for qid, docs in run.items():
        # Each score is computed independently so that one qrels file cannot
        # cut short or overwrite the score derived from the other.
        rr = _reciprocal_rank(docs, qrel.get(qid, {}), k)
        rrr = _reciprocal_rank(docs, qrell.get(qid, {}), k)
        mrr += rr
        mrrr += rrr
        if rr != rrr and not reported:
            # Report only the first disagreement, but keep accumulating so
            # the returned mean still covers every query.
            reported = True
            print(qid)
            print(qrel.get(qid, {}))
            print(qrell.get(qid, {}))
            print(mrr)
            print(mrrr)
    return mrr / len(run)
if __name__ == '__main__':
    # Script entry point: score the passage run against the reference qrels,
    # cross-checking with a second qrels file, and print the result.
    passage_qrel = "/dataset/msmarco/passage/qrels.dev.small.tsv"
    passage_qrel_ours = "/dataset/msmarco/passage/dev_ys_full.trec"
    passage_trec = "/checkpoints_local/t5/v10_passage_global_monot5_detach_fp16_0828/v10_passage_global_monot5_detach_fp16_0828-130000-test-bm25.trec"
    # Defined alongside the passage run but not passed to the call below.
    document_trec = "/checkpoints_local/t5/v10_global_detach_fp16_0607/v10_global_detach_fp16_0607-80000-test.trec"

    mrr10 = get_mrr(
        qrels=passage_qrel,
        trec=passage_trec,
        qrelss=passage_qrel_ours,
    )
    print(mrr10)