This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Add lm-eval correctness test #210

Merged: 32 commits, May 10, 2024

Changes from 12 commits

Commits (32)
952a0db
Add test framework for server
dbarbuzzi Apr 22, 2024
6178aea
Update docstring
dbarbuzzi Apr 22, 2024
c13b5c2
Add missing '__init__.py'
dbarbuzzi Apr 22, 2024
df48eef
In-line updated `ServerRunner` implementation
dbarbuzzi Apr 23, 2024
09f7161
Restore logging of server command args
dbarbuzzi Apr 24, 2024
2b32a92
Add lm-eval correctness test
dbarbuzzi Apr 24, 2024
74d0293
Add "--max-model-len" arg
dbarbuzzi Apr 24, 2024
4f6a5cf
Adjust relative tolerance value to 0.05
dbarbuzzi Apr 24, 2024
7392992
Change '--max-model-len' to 2048
dbarbuzzi Apr 25, 2024
3ebcc81
Fix comment length, remove outdated comment
dbarbuzzi Apr 25, 2024
a790a1f
Update comment
dbarbuzzi Apr 25, 2024
431f051
Skip if `lm_eval` is not available
dbarbuzzi Apr 29, 2024
44d781f
Merge branch 'main' into add-lm-eval-correctness-test
dbarbuzzi May 3, 2024
6856f24
Skip test in remote push jobs
dbarbuzzi May 3, 2024
dc33cee
Fix check in lm-eval smoke test
dbarbuzzi May 3, 2024
9bf3a71
Update lm-eval smoke job to use prebuilt wheel
dbarbuzzi May 3, 2024
c914b36
Fix typing in test
dbarbuzzi May 3, 2024
da1adf2
Add lm-eval-full job on release runs
dbarbuzzi May 3, 2024
473f8ee
Skip full test in nightly
dbarbuzzi May 3, 2024
f316375
Fix style
dbarbuzzi May 3, 2024
c61d6b2
Update eval task configs
dbarbuzzi May 3, 2024
44df6ad
Add support for configurable `rtol`
dbarbuzzi May 3, 2024
7a1ecdf
Mark 'chat-marlin' model as xfail
dbarbuzzi May 3, 2024
49d115b
Use correct label for TEST-LM-EVAL-FULL
dbarbuzzi May 3, 2024
d6571d4
Only run full lm-eval on a weekly cadence
dbarbuzzi May 6, 2024
3b25154
Update naming
dbarbuzzi May 6, 2024
5308642
Add manual release workflow
dbarbuzzi May 7, 2024
4471031
Remove xfail logic
dbarbuzzi May 7, 2024
e972635
Fix release workflow category
dbarbuzzi May 7, 2024
73adc9f
Disable marlin models
dbarbuzzi May 9, 2024
638d924
Separate nightly/weekly workflows
dbarbuzzi May 9, 2024
9828633
Additional fix for lm-eval smoke check
dbarbuzzi May 10, 2024
73 changes: 73 additions & 0 deletions tests/accuracy/lm-eval-tasks.yaml
@@ -0,0 +1,73 @@
# Llama 2 7B: FP16, FP16 sparse, marlin
- model_name: "NousResearch/Llama-2-7b-chat-hf"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.2266868840030326
        - name: "exact_match,flexible-extract"
          value: 0.22820318423047764
- model_name: "neuralmagic/Llama-2-7b-pruned50-retrained-ultrachat"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.09855951478392722
        - name: "exact_match,flexible-extract"
          value: 0.10083396512509477
  extra_args:
    --sparsity: "sparse_w16a16"
- model_name: "neuralmagic/llama-2-7b-chat-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.14101592115238817
        - name: "exact_match,flexible-extract"
          value: 0.1652767247915087
# Mistral 7B: FP16, FP16 sparse, marlin
- model_name: "teknium/OpenHermes-2.5-Mistral-7B"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.6004548900682335
        - name: "exact_match,flexible-extract"
          value: 0.6482183472327521
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.4935557240333586
        - name: "exact_match,flexible-extract"
          value: 0.5269143290371494
  extra_args:
    --sparsity: "sparse_w16a16"
- model_name: "neuralmagic/OpenHermes-2.5-Mistral-7B-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.4935557240333586
        - name: "exact_match,flexible-extract"
          value: 0.5868081880212282
# Phi 2: marlin
- model_name: "neuralmagic/phi-2-super-marlin"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.49962092494313876
        - name: "exact_match,flexible-extract"
          value: 0.5041698256254739
# Mixtral: FP16
- model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
  tasks:
    - name: "gsm8k"
      metrics:
        - name: "exact_match,strict-match"
          value: 0.6550416982562547
        - name: "exact_match,flexible-extract"
          value: 0.6603487490523123
  enable_tensor_parallel: true
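
For context on how a reference value like the ones above might be (re)generated, here is a rough sketch using the harness's Python API, outside of the server path this PR exercises; the model, task, and batch size mirror the first entry, but the exact invocation is an assumption and not part of this PR:

import lm_eval

# evaluate the Hugging Face model directly (assumes the "hf" backend of lm-eval v0.4.x)
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=NousResearch/Llama-2-7b-chat-hf",
    tasks=["gsm8k"],
    batch_size=16,  # placeholder batch size
)

# the values recorded in lm-eval-tasks.yaml correspond to these metrics
print(results["results"]["gsm8k"]["exact_match,strict-match"])
print(results["results"]["gsm8k"]["exact_match,flexible-extract"])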
108 changes: 108 additions & 0 deletions tests/accuracy/test_lm_eval_correctness.py
@@ -0,0 +1,108 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict

import numpy
import pytest
import torch
import yaml

from tests.utils.server import ServerContext

if TYPE_CHECKING:
    import lm_eval as lm_eval_t

# requires a particular lm-evaluation-harness
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10
lm_eval: "lm_eval_t" = pytest.importorskip("lm_eval", reason="lm_eval required")


class Metric(TypedDict):
    name: str
    value: float


class Task(TypedDict):
    name: str
    metrics: List[Metric]


# to support python3.8 typing prior to adding `Required`/`NotRequired`, this
# class stores the optional keys and the `EvalTaskDefinition` subclass inherits
# those alongside the required keys it defines.
class EvalTaskDefinitionOpts(TypedDict, total=False):
    enable_tensor_parallel: bool
    extra_args: Dict[str, Any]


class EvalTaskDefinition(EvalTaskDefinitionOpts):
    model_name: str
    tasks: List[Task]


TEST_DATA_FILE = Path(__file__).parent / "lm-eval-tasks.yaml"
TEST_DATA = yaml.safe_load(TEST_DATA_FILE.read_text(encoding="utf-8"))
TEST_DATA: List[EvalTaskDefinition] = [
    pytest.param(eval_def, id=eval_def["model_name"]) for eval_def in TEST_DATA
]


@pytest.mark.parametrize("eval_data", TEST_DATA)
def test_lm_eval_correctness(
    eval_data: EvalTaskDefinition,
    logger: logging.Logger,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "false")
    monkeypatch.setenv("OPENAI_API_KEY", "dummy")

    model_name = eval_data["model_name"]
    logger.info("building server startup args")
    vllm_args = {
        "--model": model_name,
        "--disable-log-requests": None,
        "--max-model-len": 2048,
    }

    if eval_data.get("enable_tensor_parallel") is True:
        tp = torch.cuda.device_count()
        logger.info("Enabling tensor parallelism with %d devices", tp)
        vllm_args["--tensor-parallel-size"] = tp

    if extra_args := eval_data.get("extra_args"):
        vllm_args.update(extra_args)

    openai_args = ",".join([
        f"model={model_name}",
        "tokenizer_backend=huggingface",
        "base_url=http://localhost:8000/v1",
    ])

    logger.info("launching server")
    with ServerContext(vllm_args, logger=logger) as _:
        task_names = [t["name"] for t in eval_data["tasks"]]
        logger.info("getting results for task_names=%s", task_names)
        results = lm_eval.simple_evaluate(
            model="local-completions",
            model_args=openai_args,
            tasks=task_names,
            batch_size=64,
        )

    logger.info("clearing torch cache")
    lm_eval.models.utils.clear_torch_cache()

    for task in eval_data["tasks"]:
        logger.info("checking metrics for task=%s", task["name"])
        for metric in task["metrics"]:
            ground_truth = metric["value"]
            measured_value = results["results"][task["name"]][metric["name"]]
            logger.info(
                "%s %s:\nground_truth=%s measured_value=%s",
                task["name"],
                metric["name"],
                ground_truth,
                measured_value,
            )

            assert numpy.isclose(ground_truth, measured_value, rtol=0.05)


It might be nice to collect results for all of the metrics, and maybe all of the tasks, that are not close, then assert at the end if there are any in error. That way all of the problems are reported, rather than having a developer fix one issue only to hit an error on the next one they hadn't fixed yet.
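
A rough sketch (not part of this change) of the aggregation that comment suggests, reusing the names from the test above; treat it as illustrative only:

failures: List[str] = []
for task in eval_data["tasks"]:
    for metric in task["metrics"]:
        ground_truth = metric["value"]
        measured_value = results["results"][task["name"]][metric["name"]]
        if not numpy.isclose(ground_truth, measured_value, rtol=0.05):
            failures.append(f"{task['name']} {metric['name']}: "
                            f"expected {ground_truth}, got {measured_value}")

# a single assertion at the end reports every out-of-tolerance metric at once
assert not failures, "metrics outside tolerance:\n" + "\n".join(failures)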

7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,6 @@
import contextlib
import gc
import logging
import os
from typing import List, Optional, Tuple

@@ -9,6 +10,7 @@
from transformers import (AutoModelForCausalLM, AutoProcessor,
                          LlavaForConditionalGeneration)

from tests.utils.logging import make_logger
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
@@ -547,3 +549,8 @@ def get_tokenizer_pool_config(tokenizer_group_type):
                                   pool_type="ray",
                                   extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture(scope="session")
def logger() -> logging.Logger:
    return make_logger("vllm_test")
Empty file added tests/utils/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions tests/utils/logging.py
@@ -0,0 +1,33 @@
import logging


def make_logger(name: str) -> logging.Logger:
    """Create a base logger"""

    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    stream_handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    return logger


def log_banner(logger: logging.Logger,
               label: str,
               body: str,
               level: int = logging.INFO):
    """
    Log a message in the "banner"-style format.

    :param logger: Instance of "logging.Logger" to use
    :param label: Label for the top of the banner
    :param body: Body content inside the banner
    :param level: Logging level to use (default: INFO)
    """

    banner = f"==== {label} ====\n{body}\n===="
    logger.log(level, "\n%s", banner)
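
As a usage note, a minimal illustration of the helpers above (the label and body are placeholders); `log_banner` is what `ServerRunner`/`ServerContext` below use to echo the startup command:

logger = make_logger("vllm_test")
log_banner(logger, "server startup command",
           "python -m vllm.entrypoints.openai.api_server --model <model>")
# emits, at INFO level:
# ==== server startup command ====
# python -m vllm.entrypoints.openai.api_server --model <model>
# ====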
133 changes: 133 additions & 0 deletions tests/utils/server.py
@@ -0,0 +1,133 @@
import logging
import os
import shlex
import subprocess
import sys
import time
from typing import Any, Dict, List, Optional

import ray
import requests
import torch

from tests.utils.logging import log_banner

MAX_SERVER_START_WAIT = 600  # time (seconds) to wait for server to start


@ray.remote(num_gpus=torch.cuda.device_count())
class ServerRunner:

    def __init__(self,
                 args: List[str],
                 *,
                 logger: Optional[logging.Logger] = None):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.startup_command = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            *args,
        ]

        if logger:
            log_banner(
                logger,
                "server startup command",
                shlex.join(self.startup_command),
                logging.DEBUG,
            )

        self.proc = subprocess.Popen(
            self.startup_command,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > MAX_SERVER_START_WAIT:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


class ServerContext:
    """
    Context manager for the lifecycle of a vLLM server, wrapping `ServerRunner`.
    """

    def __init__(self, args: Dict[str, str], *,
                 logger: logging.Logger) -> None:
        """Initialize a vLLM server

        :param args: dictionary of flags/values to pass to the server command
        :param logger: logging.Logger instance to use for logging
        """
        self._args = self._args_to_list(args)
        self._logger = logger
        self.server_runner = None

    def __enter__(self):
        """Executes the server process and waits for it to become ready."""
        ray.init(ignore_reinit_error=True)
        log_banner(self._logger, "server startup command args",
                   shlex.join(self._args))
        self.server_runner = ServerRunner.remote(self._args,
                                                 logger=self._logger)
        ray.get(self.server_runner.ready.remote())
        return self.server_runner

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Stops the server if it's still running.
        """
        if self.server_runner is not None:
            del self.server_runner
        ray.shutdown()

    def _args_to_list(self, args: Dict[str, Any]) -> List[str]:
        """
        Convert a dict mapping of CLI args to a list. All values must be
        string-able.

        :param args: `dict` containing CLI flags and their values
        :return: flattened list to pass to a CLI
        """

        arg_list: List[str] = []
        for flag, value in args.items():
            # minimal error-checking: flag names must be strings
            if not isinstance(flag, str):
                error = f"all flags must be strings, got {type(flag)} ({flag})"
                raise ValueError(error)

            arg_list.append(flag)
            if value is not None:
                arg_list.append(str(value))

        return arg_list
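
As a usage sketch (assumptions: a GPU host is available and the model name is a placeholder), showing how the args dict is flattened by `_args_to_list` before being handed to `ServerRunner`:

from tests.utils.logging import make_logger
from tests.utils.server import ServerContext

args = {
    "--model": "facebook/opt-125m",   # placeholder model
    "--max-model-len": 2048,
    "--disable-log-requests": None,   # value-less flag: only the flag is emitted
}
# _args_to_list(args) yields:
#   ["--model", "facebook/opt-125m", "--max-model-len", "2048",
#    "--disable-log-requests"]
with ServerContext(args, logger=make_logger("example")):
    ...  # the OpenAI-compatible API is served at http://localhost:8000/v1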