-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_math.py
136 lines (122 loc) · 5.19 KB
/
eval_math.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import argparse
import json
import pdb
import jsonlines
from vllm import LLM, SamplingParams
from src import utils
import sys
from tqdm import trange
import torch._dynamo; torch._dynamo.config.suppress_errors = True
# Very large integer used as an open-ended upper bound: the default `end`
# index of the evaluation slice and the stop index of the final batch slice.
MAX_INT = sys.maxsize
# NOTE(review): presumably a sentinel for an unparseable answer; it is not
# referenced by any code visible in this file - confirm before relying on it.
INVALID_ANS = "[invalid]"
# Module-level accumulator: process_results() appends one record here for every
# completion that does not contain the 'The answer is: ' marker.
invalid_outputs = []
def remove_boxed(s):
    r"""Strip a LaTeX ``\boxed{...}`` wrapper and return what is inside it.

    Args:
        s: Candidate string, expected to look exactly like ``\boxed{<answer>}``.
            Any other value (including ``None`` or a non-string) is rejected.

    Returns:
        The text between ``\boxed{`` and the final ``}``, or ``None`` when
        ``s`` is not a string of that exact shape.
    """
    left = "\\boxed{"
    # Explicit checks instead of assert + bare except: asserts vanish under
    # `python -O`, which would make malformed input return a bogus slice.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith("}")):
        return None
    return s[len(left):-1]
def process_results(doc, completion, answer):
    """Score one model completion against its reference answer.

    The predicted answer is the text following the last occurrence of
    'The answer is: ' in ``completion``, truncated at the first '.\\n',
    stripped, with one trailing period removed and stripped again. It is
    then compared with ``answer`` through ``utils.is_equiv``.

    Args:
        doc: The prompt/question; only stored when the completion is invalid.
        completion: Raw text generated by the model.
        answer: Reference answer to compare against.

    Returns:
        True when the extracted answer is equivalent to ``answer``; False
        otherwise, including when the marker is missing (in which case the
        example is also appended to the module-level ``invalid_outputs``).
    """
    pieces = completion.split('The answer is: ')

    # No marker at all: remember the example for later inspection and fail it.
    if len(pieces) <= 1:
        invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
        return False

    # Keep only the first sentence after the last marker.
    candidate = pieces[-1].split('.\n')[0].strip()
    if candidate.endswith('.'):
        candidate = candidate[:-1]
    candidate = candidate.strip()

    return True if utils.is_equiv(candidate, answer) else False
def batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive batches.

    All batches but the last hold exactly ``batch_size`` items. The last
    batch starts at the final full-batch boundary and runs to the end of the
    list, so it also absorbs any remainder (it may therefore be larger than
    ``batch_size``). A list shorter than ``batch_size`` comes back as a
    single batch, and an empty list yields one empty batch.

    Args:
        data_list: Sequence to split.
        batch_size: Nominal number of items per batch.

    Returns:
        A list of slices of ``data_list``.
    """
    full_chunks = len(data_list) // batch_size

    # Every full chunk except the final one is cut at its exact boundaries.
    batches = [
        data_list[k * batch_size:(k + 1) * batch_size]
        for k in range(full_chunks - 1)
    ]

    # The final slice is open-ended so leftover items are not dropped.
    tail_start = (full_chunks - 1) * batch_size
    batches.append(data_list[tail_start:MAX_INT])
    return batches
def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1, filepath_output=None):
    r"""Evaluate a model on a MATH-style JSONL file and write its accuracy to disk.

    Every record in ``data_path`` must provide an ``"instruction"`` (the
    problem statement) and an ``"output"`` (the reference solution). The
    reference answer is the content of the ``\boxed{...}`` expression returned
    by ``utils.last_boxed_only_string``. Prompts are generated with vLLM using
    greedy decoding, scored with ``process_results`` and the resulting
    accuracy is printed and written to ``filepath_output``.

    Args:
        model: Model name or path handed to ``vllm.LLM``.
        data_path: Path of the JSONL evaluation file.
        start: Index of the first example to evaluate (inclusive).
        end: Index of the last example to evaluate (exclusive).
        batch_size: Nominal batch size passed to ``batch_data``.
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
        filepath_output: File that receives the accuracy formatted with five
            decimals. When ``None``, ``result_math.txt`` next to ``model``
            (i.e. in the directory part of the model path) is used.
    """
    if filepath_output is None:
        # os.path handles a bare model name (no '/') correctly: the result
        # lands in the current directory instead of the filesystem root.
        filepath_output = os.path.join(os.path.dirname(model), "result_math.txt")
    print(f"Result file will be dumped to {filepath_output}")

    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('prompt =====', problem_prompt)

    hendrycks_math_ins = []
    hendrycks_math_answers = []
    # The dataset is only read, so open it read-only; "r+" would fail on
    # files or mounts without write permission.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            hendrycks_math_ins.append(problem_prompt.format(instruction=item["instruction"]))
            # May be None when the reference solution has no usable \boxed{...}.
            hendrycks_math_answers.append(
                remove_boxed(utils.last_boxed_only_string(item['output']))
            )

    print('total length ===', len(hendrycks_math_ins))
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('length ====', len(hendrycks_math_ins))
    batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    # temperature=0 requests greedy decoding.
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)

    res_completions = []
    for idx in trange(len(batch_hendrycks_math_ins), desc="Predicting on MATH dataset..."):
        prompt = batch_hendrycks_math_ins[idx]
        if not isinstance(prompt, list):
            prompt = [prompt]

        completions = llm.generate(prompt, sampling_params, use_tqdm=False)
        # Keep only the first candidate of every request.
        res_completions.extend(output.outputs[0].text for output in completions)

    results = [
        process_results(prompt, completion, prompt_answer)
        for prompt, completion, prompt_answer in zip(
            hendrycks_math_ins, res_completions, hendrycks_math_answers
        )
    ]

    # An empty evaluation slice reports 0.0 instead of dividing by zero.
    acc = sum(results) / len(results) if results else 0.0
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
    print('start===', start, ', end====', end)
    print('length====', len(results), ', acc====', acc)

    with open(filepath_output, "w") as f:
        f.write(f"{acc:.5f}")
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``data_file``, ``start``, ``end``, ``batch_size``,
        ``tensor_parallel_size`` and ``filepath_output``.
    """
    # (flag, type, default) for every supported option, in registration order.
    option_specs = (
        ("--model", str, ''),                           # model path
        ("--data_file", str, ''),                       # data path
        ("--start", int, 0),                            # start index
        ("--end", int, MAX_INT),                        # end index
        ("--batch_size", int, 400),                     # batch_size
        ("--tensor_parallel_size", int, 8),             # tensor_parallel_size
        ("--filepath_output", str, "result_math.txt"),  # accuracy output file
    )

    cli_parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli_parser.add_argument(flag, type=value_type, default=fallback)
    return cli_parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the evaluation with them.
    cli = parse_args()
    run_kwargs = {
        "model": cli.model,
        "data_path": cli.data_file,
        "start": cli.start,
        "end": cli.end,
        "batch_size": cli.batch_size,
        "tensor_parallel_size": cli.tensor_parallel_size,
        "filepath_output": cli.filepath_output,
    }
    test_hendrycks_math(**run_kwargs)