Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 2b32a92

Browse files
committed
Add lm-eval correctness test
1 parent 09f7161 commit 2b32a92

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

tests/accuracy/lm-eval-tasks.yaml

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Llama 2 7B: FP16, FP16 sparse, marlin
2+
- model_name: "NousResearch/Llama-2-7b-chat-hf"
3+
tasks:
4+
- name: "gsm8k"
5+
metrics:
6+
- name: "exact_match,strict-match"
7+
value: 0.2266868840030326
8+
- name: "exact_match,flexible-extract"
9+
value: 0.22820318423047764
10+
- model_name: "neuralmagic/Llama-2-7b-pruned50-retrained-ultrachat"
11+
tasks:
12+
- name: "gsm8k"
13+
metrics:
14+
- name: "exact_match,strict-match"
15+
value: 0.09855951478392722
16+
- name: "exact_match,flexible-extract"
17+
value: 0.10083396512509477
18+
extra_args:
19+
--sparsity: "sparse_w16a16"
20+
- model_name: "neuralmagic/llama-2-7b-chat-marlin"
21+
tasks:
22+
- name: "gsm8k"
23+
metrics:
24+
- name: "exact_match,strict-match"
25+
value: 0.14101592115238817
26+
- name: "exact_match,flexible-extract"
27+
value: 0.1652767247915087
28+
# Mistral 7B: FP16, FP16 sparse, marlin
29+
- model_name: "teknium/OpenHermes-2.5-Mistral-7B"
30+
tasks:
31+
- name: "gsm8k"
32+
metrics:
33+
- name: "exact_match,strict-match"
34+
value: 0.6004548900682335
35+
- name: "exact_match,flexible-extract"
36+
value: 0.6482183472327521
37+
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
38+
tasks:
39+
- name: "gsm8k"
40+
metrics:
41+
- name: "exact_match,strict-match"
42+
value: 0.4935557240333586
43+
- name: "exact_match,flexible-extract"
44+
value: 0.5269143290371494
45+
extra_args:
46+
--sparsity: "sparse_w16a16"
47+
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-marlin"
48+
tasks:
49+
- name: "gsm8k"
50+
metrics:
51+
- name: "exact_match,strict-match"
52+
value: 0.4935557240333586
53+
- name: "exact_match,flexible-extract"
54+
value: 0.5868081880212282
55+
# Phi 2: marlin
56+
- model_name: "neuralmagic/phi-2-super-marlin"
57+
tasks:
58+
- name: "gsm8k"
59+
metrics:
60+
- name: "exact_match,strict-match"
61+
value: 0.49962092494313876
62+
- name: "exact_match,flexible-extract"
63+
value: 0.5041698256254739
64+
# Mixtral: FP16
65+
- model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
66+
tasks:
67+
- name: "gsm8k"
68+
metrics:
69+
- name: "exact_match,strict-match"
70+
value: 0.6550416982562547
71+
- name: "exact_match,flexible-extract"
72+
value: 0.6603487490523123
73+
enable_tensor_parallel: true
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Any, Dict, List, TypedDict
4+
5+
import lm_eval
6+
import numpy
7+
import pytest
8+
import torch
9+
import yaml
10+
11+
from tests.utils.server import ServerContext
12+
13+
14+
class Metric(TypedDict):
15+
name: str
16+
value: float
17+
18+
19+
class Task(TypedDict):
20+
name: str
21+
metrics: List[Metric]
22+
23+
24+
# to support python3.8 typing prior to adding `Required`/`NotRequired`, this class
25+
# stores the optional keys and the `EvalDefinition` subclass inherits those alongside
26+
# the required keys it defines.
27+
class EvalTaskDefinitionOpts(TypedDict, total=False):
28+
enable_tensor_parallel: bool
29+
extra_args: Dict[str, Any]
30+
31+
32+
class EvalTaskDefinition(EvalTaskDefinitionOpts):
33+
model_name: str
34+
tasks: List[Task]
35+
36+
37+
TEST_DATA_FILE = Path(__file__).parent / "lm-eval-tasks.yaml"
38+
TEST_DATA = yaml.safe_load(TEST_DATA_FILE.read_text(encoding="utf-8"))
39+
TEST_DATA: List[EvalTaskDefinition] = [
40+
pytest.param(eval_def, id=eval_def["model_name"]) for eval_def in TEST_DATA
41+
]
42+
43+
44+
@pytest.mark.parametrize("eval_data", TEST_DATA)
45+
def test_lm_eval_correctness(
46+
eval_data: EvalTaskDefinition,
47+
logger: logging.Logger,
48+
monkeypatch: pytest.MonkeyPatch,
49+
):
50+
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "false")
51+
monkeypatch.setenv("OPENAI_API_KEY", "dummy")
52+
53+
model_name = eval_data["model_name"]
54+
logger.info("building server startup args")
55+
vllm_args = {"--model": model_name, "--disable-log-requests": None}
56+
57+
if eval_data.get("enable_tensor_parallel") is True:
58+
tp = torch.cuda.device_count()
59+
logger.info("Enabling tensor parallelism with %d devices", tp)
60+
vllm_args["--tensor-parallel-size"] = tp
61+
62+
if extra_args := eval_data.get("extra_args"):
63+
vllm_args.update(extra_args)
64+
65+
openai_args = ",".join(
66+
[
67+
f"model={model_name}",
68+
"tokenizer_backend=huggingface",
69+
"base_url=http://localhost:8000/v1",
70+
]
71+
)
72+
73+
logger.info("launching server")
74+
with ServerContext(vllm_args, logger=logger) as _:
75+
task_names = [t["name"] for t in eval_data["tasks"]]
76+
logger.info("getting results for task_names=%s", task_names)
77+
results = lm_eval.simple_evaluate(
78+
model="local-completions",
79+
model_args=openai_args,
80+
tasks=task_names,
81+
batch_size=64,
82+
)
83+
84+
logger.info("clearing torch cache")
85+
lm_eval.models.utils.clear_torch_cache()
86+
87+
for task in eval_data["tasks"]:
88+
logger.info("checking metrics for task=%s", task["name"])
89+
for metric in task["metrics"]:
90+
ground_truth = metric["value"]
91+
measured_value = results["results"][task["name"]][metric["name"]]
92+
logger.info(
93+
"%s %s:\nground_truth=%s measured_value=%s",
94+
task["name"],
95+
metric["name"],
96+
ground_truth,
97+
measured_value,
98+
)
99+
100+
# Metrics must be within 1% of the larger of the two values. This
101+
# corresponds to a 99% accuracy threshold.
102+
assert numpy.isclose(ground_truth, measured_value, rtol=0.01)

0 commit comments

Comments
 (0)