
Commit 130bfc4

Add new evaluation metrics
1 parent 26e790d commit 130bfc4

File tree

1 file changed: +11 −1 lines changed


torchao/_models/llama/eval.py

Lines changed: 11 additions & 1 deletion
@@ -42,12 +42,20 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: Optional[bool] = False,
+    eval_difficulty: Optional[str] = "easy",
 ):
     """Runs the evaluation of a model using LM Eval."""
+
+    # Select eval tasks based on difficulty level
+    if eval_difficulty == "medium":
+        tasks.extend(['mmlu'])
+    elif eval_difficulty == "hard":
+        tasks.extend(['mmlu', 'truthfulqa_mc2', 'winogrande', 'arc_challenge', 'hellaswag', 'gsm8k'])
+
     print(
         f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
         +f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
-        +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
+        +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}, eval_difficulty: {eval_difficulty}\n"
     )
     torchao.quantization.utils.recommended_inductor_config_setter()

@@ -218,6 +226,7 @@ def run_evaluation(
     parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
     parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
     parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
+    parser.add_argument('-d', '--eval_difficulty', type=str, default="easy", help='difficulty of eval, one of [easy, medium, hard]')

     args = parser.parse_args()
     run_evaluation(
@@ -233,4 +242,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.eval_difficulty,
     )
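
For reference, the task-selection logic added in the first hunk can be exercised on its own. The sketch below is illustrative only: the select_tasks helper and the ['wikitext'] base task list are assumptions for the example, not part of this commit (the script itself calls tasks.extend in place inside run_evaluation).

from typing import List

# Hypothetical helper mirroring the commit's difficulty handling.
def select_tasks(eval_difficulty: str, base_tasks: List[str]) -> List[str]:
    tasks = list(base_tasks)  # copy so the caller's list is untouched
    if eval_difficulty == "medium":
        tasks.extend(['mmlu'])
    elif eval_difficulty == "hard":
        tasks.extend(['mmlu', 'truthfulqa_mc2', 'winogrande',
                      'arc_challenge', 'hellaswag', 'gsm8k'])
    return tasks  # "easy" leaves the base task list unchanged

print(select_tasks("hard", ['wikitext']))

On the command line the new option is exposed as -d/--eval_difficulty, so an invocation such as python torchao/_models/llama/eval.py -d hard (other flags omitted) would run the hard task set.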
