This repository was archived by the owner on Oct 11, 2024. It is now read-only.
forked from vllm-project/vllm
Add lm-eval correctness test #210
Merged
Changes from 12 commits (32 commits total)
952a0db Add test framework for server (dbarbuzzi)
6178aea Update docstring (dbarbuzzi)
c13b5c2 Add missing '__init__.py' (dbarbuzzi)
df48eef In-line updated `ServerRunner` implementation (dbarbuzzi)
09f7161 Restore logging of server command args (dbarbuzzi)
2b32a92 Add lm-eval correctness test (dbarbuzzi)
74d0293 Add "--max-model-len" arg (dbarbuzzi)
4f6a5cf Adjust relative tolerance value to 0.05 (dbarbuzzi)
7392992 Change '--max-model-len' to 2048 (dbarbuzzi)
3ebcc81 Fix comment length, remove outdated comment (dbarbuzzi)
a790a1f Update comment (dbarbuzzi)
431f051 Skip if `lm_eval` is not available (dbarbuzzi)
44d781f Merge branch 'main' into add-lm-eval-correctness-test (dbarbuzzi)
6856f24 Skip test in remote push jobs (dbarbuzzi)
dc33cee Fix check in lm-eval smoke test (dbarbuzzi)
9bf3a71 Update lm-eval smoke job to use prebuilt wheel (dbarbuzzi)
c914b36 Fix typing in test (dbarbuzzi)
da1adf2 Add lm-eval-full job on release runs (dbarbuzzi)
473f8ee Skip full test in nightly (dbarbuzzi)
f316375 Fix style (dbarbuzzi)
c61d6b2 Update eval task configs (dbarbuzzi)
44df6ad Add support for configurable `rtol` (dbarbuzzi)
7a1ecdf Mark 'chat-marlin' model as xfail (dbarbuzzi)
49d115b Use correct label for TEST-LM-EVAL-FULL (dbarbuzzi)
d6571d4 Only run full lm-eval on a weekly cadence (dbarbuzzi)
3b25154 Update naming (dbarbuzzi)
5308642 Add manual release workflow (dbarbuzzi)
4471031 Remove xfail logic (dbarbuzzi)
e972635 Fix release workflow category (dbarbuzzi)
73adc9f Disable marlin models (dbarbuzzi)
638d924 Separate nightly/weekly workflows (dbarbuzzi)
9828633 Additional fix for lm-eval smoke check (dbarbuzzi)
Files changed:
@@ -0,0 +1,73 @@
# Llama 2 7B: FP16, FP16 sparse, marlin
- model_name: "NousResearch/Llama-2-7b-chat-hf"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.2266868840030326
        - name: "exact_match,flexible-extract"
          value: 0.22820318423047764
- model_name: "neuralmagic/Llama-2-7b-pruned50-retrained-ultrachat"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.09855951478392722
        - name: "exact_match,flexible-extract"
          value: 0.10083396512509477
  extra_args:
    --sparsity: "sparse_w16a16"
- model_name: "neuralmagic/llama-2-7b-chat-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.14101592115238817
        - name: "exact_match,flexible-extract"
          value: 0.1652767247915087
# Mistral 7B: FP16, FP16 sparse, marlin
- model_name: "teknium/OpenHermes-2.5-Mistral-7B"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.6004548900682335
        - name: "exact_match,flexible-extract"
          value: 0.6482183472327521
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.4935557240333586
        - name: "exact_match,flexible-extract"
          value: 0.5269143290371494
  extra_args:
    --sparsity: "sparse_w16a16"
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.4935557240333586
        - name: "exact_match,flexible-extract"
          value: 0.5868081880212282
# Phi 2: marlin
- model_name: "neuralmagic/phi-2-super-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.49962092494313876
        - name: "exact_match,flexible-extract"
          value: 0.5041698256254739
# Mixtral: FP16
- model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.6550416982562547
        - name: "exact_match,flexible-extract"
          value: 0.6603487490523123
  enable_tensor_parallel: true
@@ -0,0 +1,108 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict

import numpy
import pytest
import torch
import yaml

from tests.utils.server import ServerContext

if TYPE_CHECKING:
    import lm_eval as lm_eval_t

# requires a particular lm-evaluation-harness
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10
lm_eval: "lm_eval_t" = pytest.importorskip("lm_eval", reason="lm_eval required")


class Metric(TypedDict):
    name: str
    value: float


class Task(TypedDict):
    name: str
    metrics: List[Metric]


# to support python3.8 typing prior to adding `Required`/`NotRequired`, this
# class stores the optional keys and the `EvalTaskDefinition` subclass inherits
# those alongside the required keys it defines.
class EvalTaskDefinitionOpts(TypedDict, total=False):
    enable_tensor_parallel: bool
    extra_args: Dict[str, Any]


class EvalTaskDefinition(EvalTaskDefinitionOpts):
    model_name: str
    tasks: List[Task]


TEST_DATA_FILE = Path(__file__).parent / "lm-eval-tasks.yaml"
TEST_DATA = yaml.safe_load(TEST_DATA_FILE.read_text(encoding="utf-8"))
TEST_DATA: List[EvalTaskDefinition] = [
    pytest.param(eval_def, id=eval_def["model_name"]) for eval_def in TEST_DATA
]


@pytest.mark.parametrize("eval_data", TEST_DATA)
def test_lm_eval_correctness(
    eval_data: EvalTaskDefinition,
    logger: logging.Logger,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "false")
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")

    model_name = eval_data["model_name"]
    logger.info("building server startup args")
    vllm_args = {
        "--model": model_name,
        "--disable-log-requests": None,
        "--max-model-len": 2048,
    }

    if eval_data.get("enable_tensor_parallel") is True:
        tp = torch.cuda.device_count()
        logger.info("Enabling tensor parallelism with %d devices", tp)
        vllm_args["--tensor-parallel-size"] = tp

    if extra_args := eval_data.get("extra_args"):
        vllm_args.update(extra_args)

    openai_args = ",".join([
        f"model={model_name}",
        "tokenizer_backend=huggingface",
        "base_url=http://localhost:8000/v1",
    ])

    logger.info("launching server")
    with ServerContext(vllm_args, logger=logger) as _:
        task_names = [t["name"] for t in eval_data["tasks"]]
        logger.info("getting results for task_names=%s", task_names)
        results = lm_eval.simple_evaluate(
            model="local-completions",
            model_args=openai_args,
            tasks=task_names,
            batch_size=64,
        )

    logger.info("clearing torch cache")
    lm_eval.models.utils.clear_torch_cache()

    for task in eval_data["tasks"]:
        logger.info("checking metrics for task=%s", task["name"])
        for metric in task["metrics"]:
            ground_truth = metric["value"]
            measured_value = results["results"][task["name"]][metric["name"]]
            logger.info(
                "%s %s:\nground_truth=%s measured_value=%s",
                task["name"],
                metric["name"],
                ground_truth,
                measured_value,
            )

            assert numpy.isclose(ground_truth, measured_value, rtol=0.05)
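Note that the `logger` fixture used by the test is not defined anywhere in this diff; presumably a `conftest.py` elsewhere in the test tree provides it. A minimal sketch of such a fixture, assuming it reuses `make_logger` from `tests/utils/logging.py` (shown further below), might look like this:

# Hypothetical conftest.py fixture -- not part of this PR's diff.
import logging

import pytest

from tests.utils.logging import make_logger


@pytest.fixture
def logger(request) -> logging.Logger:
    # Name the logger after the currently running test for readable output.
    return make_logger(request.node.name)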
Empty file.
@@ -0,0 +1,33 @@
import logging


def make_logger(name: str) -> logging.Logger:
    """Create a base logger"""

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    return logger


def log_banner(logger: logging.Logger,
               label: str,
               body: str,
               level: int = logging.INFO):
    """
    Log a message in the "banner"-style format.

    :param logger: Instance of "logging.Logger" to use
    :param label: Label for the top of the banner
    :param body: Body content inside the banner
    :param level: Logging level to use (default: INFO)
    """

    banner = f"==== {label} ====\n{body}\n===="
    logger.log(level, "\n%s", banner)
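For illustration, a short usage sketch of these helpers (the banner text below is made up; the import path follows the `tests.utils.logging` module referenced elsewhere in this PR):

from tests.utils.logging import make_logger, log_banner

logger = make_logger("example")

# Emits a block like:
# ==== server startup command ====
# python -m vllm.entrypoints.openai.api_server --model <model>
# ====
log_banner(logger, "server startup command",
           "python -m vllm.entrypoints.openai.api_server --model <model>")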
@@ -0,0 +1,133 @@
import logging
import os
import shlex
import subprocess
import sys
import time
from typing import Any, Dict, List, Optional

import ray
import requests
import torch

from tests.utils.logging import log_banner

MAX_SERVER_START_WAIT = 600  # time (seconds) to wait for server to start


@ray.remote(num_gpus=torch.cuda.device_count())
class ServerRunner:

    def __init__(self,
                 args: List[str],
                 *,
                 logger: Optional[logging.Logger] = None):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.startup_command = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            *args,
        ]

        if logger:
            log_banner(
                logger,
                "server startup command",
                shlex.join(self.startup_command),
                logging.DEBUG,
            )

        self.proc = subprocess.Popen(
            [
                sys.executable, "-m", "vllm.entrypoints.openai.api_server",
                *args
            ],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > MAX_SERVER_START_WAIT:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


class ServerContext:
    """
    Context manager for the lifecycle of a vLLM server, wrapping `ServerRunner`.
    """

    def __init__(self, args: Dict[str, str], *,
                 logger: logging.Logger) -> None:
        """Initialize a vLLM server

        :param args: dictionary of flags/values to pass to the server command
        :param logger: logging.Logger instance to use for logging
        """
        self._args = self._args_to_list(args)
        self._logger = logger
        self.server_runner = None

    def __enter__(self):
        """Executes the server process and waits for it to become ready."""
        ray.init(ignore_reinit_error=True)
        log_banner(self._logger, "server startup command args",
                   shlex.join(self._args))
        self.server_runner = ServerRunner.remote(self._args,
                                                 logger=self._logger)
        ray.get(self.server_runner.ready.remote())
        return self.server_runner

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Stops the server if it's still running.
        """
        if self.server_runner is not None:
            del self.server_runner
        ray.shutdown()

    def _args_to_list(self, args: Dict[str, Any]) -> List[str]:
        """
        Convert a dict mapping of CLI args to a list. All values must be
        string-able.

        :param args: `dict` containing CLI flags and their values
        :return: flattened list to pass to a CLI
        """

        arg_list: List[str] = []
        for flag, value in args.items():
            # minimal error-checking: flag names must be strings
            if not isinstance(flag, str):
                error = f"all flags must be strings, got {type(flag)} ({flag})"
                raise ValueError(error)

            arg_list.append(flag)
            if value is not None:
                arg_list.append(str(value))

        return arg_list
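For reference, a minimal standalone sketch of driving `ServerContext` outside of pytest. The model name is illustrative only, and the assumption that the server listens on the default port 8000 follows from the health-check URL above:

from tests.utils.logging import make_logger
from tests.utils.server import ServerContext

logger = make_logger("server-demo")
vllm_args = {
    "--model": "facebook/opt-125m",  # illustrative model, not from this PR
    "--disable-log-requests": None,  # a value of None means a bare flag
    "--max-model-len": 2048,
}

with ServerContext(vllm_args, logger=logger):
    # While the block is active, an OpenAI-compatible API is served at
    # http://localhost:8000/v1 and can be queried (e.g. by lm_eval).
    ...
# On exit, the ServerRunner actor is released and ray.shutdown() is called.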
Review comment:
It might be nice to collect the results for all of the metrics (and maybe all of the tasks) that are not close, then assert if any are in error. That way all of the problems are reported at once, rather than a developer fixing one issue and then hitting the next error they hadn't fixed yet.
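A hedged sketch of that suggestion, reusing the names from the test above (not part of the PR), could replace the per-metric assert:

# Gather every out-of-tolerance metric before failing, so all problems
# are reported in a single run.
failures = []
for task in eval_data["tasks"]:
    for metric in task["metrics"]:
        ground_truth = metric["value"]
        measured_value = results["results"][task["name"]][metric["name"]]
        if not numpy.isclose(ground_truth, measured_value, rtol=0.05):
            failures.append(
                f"{task['name']} {metric['name']}: "
                f"expected {ground_truth}, got {measured_value}")

assert not failures, "metrics outside tolerance:\n" + "\n".join(failures)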