Commit 5a1803e

Add new evaluation metrics
1 parent 26e790d · commit 5a1803e


torchao/_models/llama/eval.py

Lines changed: 9 additions & 0 deletions
@@ -42,6 +42,7 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: Optional[bool] = False,
+    eval_difficulty: Optional[str] = "easy",
 ):
     """Runs the evaluation of a model using LM Eval."""
     print(
@@ -179,6 +180,12 @@ def run_evaluation(
     model.to(device)
     model.reset_caches()

+    # Select eval tasks based on difficulty level
+    if eval_difficulty == "medium":
+        tasks.extend(['mmlu'])
+    elif eval_difficulty == "hard":
+        tasks.extend(['mmlu', 'truthfulqa_mc2', 'winogrande', 'arc_challenge', 'hellaswag', 'gsm8k'])
+
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
@@ -218,6 +225,7 @@ def run_evaluation(
     parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
     parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
     parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
+    parser.add_argument('--eval_difficulty', type=str, default="easy", help='difficulty of eval, one of [easy, medium, hard]')

     args = parser.parse_args()
     run_evaluation(
@@ -233,4 +241,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.eval_difficulty,
     )
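For reference, a usage sketch (not part of the commit; any other flags the script requires, such as a checkpoint path, are omitted here): the new option is an ordinary argparse flag, so the hardest tier would be selected with

    python torchao/_models/llama/eval.py --eval_difficulty hard

With the default of "easy" the task list is left unchanged, so existing invocations keep their current behavior; "medium" appends mmlu, and "hard" appends mmlu plus five harder benchmarks (truthfulqa_mc2, winogrande, arc_challenge, hellaswag, gsm8k).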
