forked from ContextualAI/gritlm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge_cqadupstack.py
68 lines (60 loc) · 2.35 KB
/
merge_cqadupstack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Merges CQADupstack subset results
Usage: python merge_cqadupstack.py path_to_results_folder
"""
import glob
import json
import logging
import os
import sys
logger = logging.getLogger(__name__)

# The twelve per-subset CQADupstack tasks expected in the results folder.
TASK_LIST_CQA = [
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
]

# Keys inside an evaluation split that are NOT averaged across subsets:
# "evaluation_time" is summed; the remaining keys are copied through as-is
# (the value from the last file read wins).
NOAVG_KEYS = [
    "evaluation_time",
    "mteb_version",
    "mteb_dataset_name",
    "dataset_revision",
]

# Top-level result keys that hold per-split metric dicts. Every other
# top-level key is treated as metadata and copied through unchanged.
EVAL_SPLITS = ("train", "validation", "dev", "test")

# Name of the merged task; also the stem of the output file.
MERGED_TASK_NAME = "CQADupstackRetrieval"

USAGE = "Usage: python merge_cqadupstack.py path_to_results_folder"


def _task_name(path):
    """Return the task name of a result file: its basename without extension."""
    return os.path.splitext(os.path.basename(path))[0]


def merge_cqadupstack(results_folder):
    """Merge the per-subset CQADupstack result files found in *results_folder*.

    Within every evaluation split (``EVAL_SPLITS``) each metric is averaged
    over the ``len(TASK_LIST_CQA)`` subsets, except for ``NOAVG_KEYS``:
    ``evaluation_time`` is summed and the others are copied through. Top-level
    keys outside the evaluation splits are copied through (last file wins),
    and ``mteb_dataset_name`` is set to ``CQADupstackRetrieval``. The merged
    dict is written to ``<results_folder>/CQADupstackRetrieval.json``.

    Args:
        results_folder: Directory holding one ``<task>.json`` file per entry
            of ``TASK_LIST_CQA``.

    Returns:
        The merged results dict, or ``None`` (after logging a warning) when
        the folder does not contain exactly the expected set of subset files.
    """
    # Require at least one character between "CQADupstack" and "Retrieval" so
    # an already merged CQADupstackRetrieval.json is not picked up again.
    # The folder path is escaped so glob metacharacters in it match literally.
    pattern = os.path.join(glob.escape(results_folder), "CQADupstack*?*Retrieval.json")
    # Sorted so that the merge (float summation order, last-wins metadata)
    # does not depend on directory listing order.
    files = sorted(glob.glob(pattern))
    logger.info("Found CQADupstack files: %s", files)

    # Compare task names, not just the count: twelve files with a wrong or
    # extra name must not be merged silently.
    found = {_task_name(path) for path in files}
    expected = set(TASK_LIST_CQA)
    if found != expected:
        logger.warning(
            "Got %d, but expected %d files. Missing: %s; Too much: %s",
            len(files),
            len(TASK_LIST_CQA),
            expected - found,
            found - expected,
        )
        return None

    num_tasks = len(TASK_LIST_CQA)
    all_results = {}
    for file_name in files:
        with open(file_name, "r", encoding="utf-8") as f:
            results = json.load(f)
        for split, split_results in results.items():
            if split not in EVAL_SPLITS:
                # Metadata such as the version: copy through, last file wins.
                all_results[split] = split_results
                continue
            merged_split = all_results.setdefault(split, {})
            for metric, score in split_results.items():
                if metric == "evaluation_time":
                    merged_split[metric] = merged_split.get(metric, 0) + score
                elif metric in NOAVG_KEYS:
                    merged_split[metric] = score
                else:
                    # Accumulate the mean incrementally: each subset adds 1/N.
                    merged_split[metric] = merged_split.get(metric, 0) + score / num_tasks

    all_results["mteb_dataset_name"] = MERGED_TASK_NAME
    # Lazy %-style args: the original passed an argument without a
    # placeholder, which made logging emit an internal formatting error.
    logger.info("Saving %s", all_results)
    out_path = os.path.join(results_folder, f"{MERGED_TASK_NAME}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=4)
    return all_results


def main(argv=None):
    """Command-line entry point: merge the results folder given as first argument.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``. Extra arguments beyond the first are ignored.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        sys.exit(USAGE)
    logging.basicConfig(level=logging.INFO)
    merge_cqadupstack(args[0])


if __name__ == "__main__":
    main()