Skip to content

Commit ce8880f

Browse files
authored
Merge pull request #185 from danmcp/answersunittests
Add reorg answer file test
2 parents e5d89c6 + 894e658 commit ce8880f

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

src/instructlab/eval/mt_bench_answers.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,16 +26,19 @@
26 26
def reorg_answer_file(answer_file):
27 27
"""Sort by question id and de-duplication"""
28 28
logger.debug(locals())
29-
answers = {}
30-
with open(answer_file, "r", encoding="utf-8") as fin:
31-
for l in fin:
29+
with open(answer_file, "r+", encoding="utf-8") as f:
30+
answers = {}
31+
for l in f:
32 32
qid = json.loads(l)["question_id"]
33 33
answers[qid] = l
34 34

35-
qids = sorted(list(answers.keys()))
36-
with open(answer_file, "w", encoding="utf-8") as fout:
35+
# Reset to the beginning of the file and clear it
36+
f.seek(0)
37+
f.truncate()
38+
39+
qids = sorted(list(answers.keys()))
37 40
for qid in qids:
38-
fout.write(answers[qid])
41+
f.write(answers[qid])
39 42

40 43

41 44
def get_answer(

tests/test_mt_bench_answers.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Standard
4+
import json
5+
import os
6+
import random
7+
import shutil
8+
import tempfile
9+
10+
# First Party
11+
from instructlab.eval.mt_bench_answers import reorg_answer_file
12+
13+
14+
def test_reorg_answer_file():
15+
answer_file = os.path.join(
16+
os.path.dirname(__file__),
17+
"..",
18+
"src",
19+
"instructlab",
20+
"eval",
21+
"data",
22+
"mt_bench",
23+
"reference_answer",
24+
"gpt-4.jsonl",
25+
)
26+
27+
# Create a temporary file
28+
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
29+
temp_answer_file = temp_file.name
30+
31+
# Copy the original file to the temp file
32+
shutil.copy(answer_file, temp_answer_file)
33+
34+
orig_length = 0
35+
with open(temp_answer_file, "r+", encoding="utf-8") as f:
36+
answers = {}
37+
for l in f:
38+
orig_length += 1
39+
qid = json.loads(l)["question_id"]
40+
answers[qid] = l
41+
42+
# Reset to the beginning of the file and clear it
43+
f.seek(0)
44+
f.truncate()
45+
46+
# Randomize the values
47+
qids = sorted(list(answers.keys()), key=lambda answer: random.random())
48+
for qid in qids:
49+
f.write(answers[qid])
50+
# Write each answer twice
51+
f.write(answers[qid])
52+
53+
# Run the reorg which should sort and dedup the file in place
54+
reorg_answer_file(temp_answer_file)
55+
56+
new_length = 0
57+
with open(temp_answer_file, "r", encoding="utf-8") as fin:
58+
previous_question_id = -1
59+
for l in fin:
60+
new_length += 1
61+
qid = json.loads(l)["question_id"]
62+
assert qid > previous_question_id
63+
previous_question_id = qid
64+
65+
assert new_length == orig_length

0 commit comments

Comments
 (0)