6 changes: 6 additions & 0 deletions auto_round/__main__.py
@@ -21,6 +21,12 @@ def run_eval():
     elif "--lmms" in sys.argv:
         sys.argv.remove("--lmms")
         run_lmms()
+    elif "--vllm" in sys.argv:
+        sys.argv.remove("--vllm")
+        from auto_round.script.llm import eval_with_vllm, setup_eval_parser
+
+        args = setup_eval_parser()
+        eval_with_vllm(args)
     else:
         from auto_round.script.llm import setup_eval_parser
 
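For context, a minimal sketch of the invocation this dispatch enables; the model path is a placeholder. Because `--vllm` is matched by membership in `sys.argv` and removed before `setup_eval_parser()` runs, the flag may appear anywhere on the command line:

~~~
# Hypothetical invocation; <model_name_or_path> is a placeholder.
auto-round --eval --vllm --model <model_name_or_path> --tasks lambada_openai
~~~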
90 changes: 81 additions & 9 deletions auto_round/script/llm.py
@@ -42,6 +42,8 @@
     set_cuda_visible_devices,
 )
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 
 class BasicArgumentParser(argparse.ArgumentParser):
 
@@ -322,6 +324,27 @@ def __init__(self, *args, **kwargs):
             help="Limit the number of examples per task. "
             "If <1, limit is a percentage of the total number of examples.",
         )
+        # vllm related arguments
+        self.add_argument("--revision", default=None, type=str, help="model revision for vllm")
+        self.add_argument("--tokenizer", default=None, type=str, help="tokenizer to use with vllm")
+        self.add_argument(
+            "--tokenizer_mode", default="auto", type=str, help="tokenizer mode for vllm (e.g. auto/fast/slow)"
+        )
+        self.add_argument("--tokenizer_revision", default=None, type=str, help="tokenizer revision for vllm")
+        self.add_argument("--add_bos_token", action="store_true", help="add BOS token when using vllm")
+        self.add_argument("--prefix_token_id", default=None, type=int, help="prefix token id for vllm")
+        self.add_argument("--tensor_parallel_size", default=1, type=int, help="tensor parallel size for vllm")
+        self.add_argument("--data_parallel_size", default=1, type=int, help="data parallel size for vllm")
+        self.add_argument("--quantization", default=None, type=str, help="quantization setting for vllm")
+        self.add_argument("--max_gen_toks", default=256, type=int, help="max generation tokens for vllm")
+        self.add_argument("--swap_space", default=4, type=float, help="swap space (GB) for vllm")
+        self.add_argument("--max_batch_size", default=None, type=int, help="max batch size for vllm")
+        self.add_argument("--max_length", default=None, type=int, help="maximum sequence length for vllm")
+        self.add_argument("--max_model_len", default=None, type=int, help="maximum model sequence length for vllm")
+        self.add_argument(
+            "--gpu_memory_utilization", default=0.9, type=float, help="target GPU memory utilization for vllm"
+        )
+        self.add_argument("--lora_local_path", default=None, type=str, help="local LoRA path for vllm")
 
 
 def setup_parser():
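As a sketch of how the new flags compose with the existing eval options, the command below is illustrative; the model name and tuning values are assumptions, not part of this change:

~~~
# Illustrative values only; tune per deployment.
auto-round --model Qwen/Qwen2.5-0.5B-Instruct --eval --vllm \
  --tasks lambada_openai --eval_bs 8 \
  --tensor_parallel_size 2 --gpu_memory_utilization 0.85 --max_model_len 4096
~~~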
@@ -786,15 +809,16 @@ def eval(args):
     if (batch_size := args.eval_bs) is None:
         batch_size = "auto:8"
     is_gguf_file = False
-    if os.path.isfile(args.model) and args.model.endswith(".gguf"):
-        is_gguf_file = True
-        gguf_file = os.path.basename(args.model)
-        model = os.path.dirname(args.model)
-    else:
-        for file in os.listdir(args.model):
-            if file.endswith(".gguf"):
-                is_gguf_file = True
-                gguf_file = file
+    if os.path.exists(args.model):
+        if os.path.isfile(args.model) and args.model.endswith(".gguf"):
+            is_gguf_file = True
+            gguf_file = os.path.basename(args.model)
+            model = os.path.dirname(args.model)
+        else:
+            for file in os.listdir(args.model):
+                if file.endswith(".gguf"):
+                    is_gguf_file = True
+                    gguf_file = file
     eval_model_dtype = get_model_dtype(args.eval_model_dtype)
     if is_gguf_file:
         import torch
@@ -949,3 +973,51 @@ def eval_task_by_task(
     print(make_table(res_all))
 
     print("total eval time:", time.time() - st)
+
+
+def eval_with_vllm(args):
+    import time
+
+    st = time.time()
+
+    from lm_eval import evaluator  # pylint: disable=E0401
+    from lm_eval.models.vllm_causallms import VLLM  # pylint: disable=E0401
+    from lm_eval.utils import make_table  # pylint: disable=E0401
+
+    device_str, _ = get_device_and_parallelism(args.device)
+    eval_model_dtype = get_model_dtype(args.eval_model_dtype, "auto")
+    if (batch_size := args.eval_bs) is None:
+        batch_size = "auto:8"
+
+    vllm_lm = VLLM(
+        pretrained=args.model,
+        dtype=eval_model_dtype,
+        revision=args.revision,
+        trust_remote_code=not args.disable_trust_remote_code,
+        tokenizer=args.tokenizer,
+        tokenizer_mode=args.tokenizer_mode,
+        tokenizer_revision=args.tokenizer_revision,
+        add_bos_token=args.add_bos_token,
+        prefix_token_id=args.prefix_token_id,
+        tensor_parallel_size=args.tensor_parallel_size,
+        quantization=args.quantization,
+        max_gen_toks=args.max_gen_toks,
+        swap_space=args.swap_space,
+        batch_size=batch_size,
+        max_batch_size=args.max_batch_size,
+        max_length=args.max_length,
+        max_model_len=args.max_model_len,
+        seed=args.seed,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        device=device_str,
+        data_parallel_size=args.data_parallel_size,
+        lora_local_path=args.lora_local_path,
+    )
+    res = evaluator.simple_evaluate(
+        model=vllm_lm,
+        tasks=args.tasks,
+        limit=args.limit,
+    )
+
+    print(make_table(res))
+    print("evaluation running time=%ds" % (time.time() - st))
2 changes: 2 additions & 0 deletions docs/step_by_step.md
@@ -617,6 +617,8 @@
 ~~~
 The last format will be used in evaluation if multiple formats have been exported.
 
+Note: To use the vLLM backend, add `--vllm` to the command above.
+
 ### Eval the Quantized model
 
 - AutoRound format
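A concrete command matching the note above might look like the following; the model path is a placeholder and the task choice is illustrative (it mirrors the CLI exercised by the new test below):

~~~
auto-round --model <quantized_model_dir> --eval --vllm --tasks lambada_openai --eval_bs 8
~~~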
31 changes: 31 additions & 0 deletions test/test_cuda/test_vllm.py
@@ -7,6 +7,10 @@
 Run `pytest test/test_cuda/test_vllm.py`.
 """
 
+import os
+import shutil
+import subprocess
+
 import pytest
 from vllm import LLM, SamplingParams
 from vllm.platforms import current_platform
@@ -43,3 +47,30 @@ def test_auto_round(model):
         generated_text = output.outputs[0].text
         if "France" in prompt:
             assert "Paris" in generated_text
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_vllm_lm_eval(model):
+    if shutil.which("auto-round") is None:
+        pytest.skip("auto-round CLI not available")
+
+    env = os.environ.copy()
+    env["VLLM_SKIP_WARMUP"] = "true"
+    env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+    cmd = [
+        "auto-round",
+        "--model",
+        model,
+        "--eval",
+        "--tasks",
+        "lambada_openai",
+        "--eval_bs",
+        "8",
+        "--limit",
+        "10",
+        "--vllm",
+    ]
+
+    proc = subprocess.run(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+    assert proc.returncode == 0, f"auto-round failed (rc={proc.returncode}):\n{proc.stdout}"
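To run only this new test, a standard pytest keyword filter should work; the `-k` selector is stock pytest usage, not part of this change:

~~~
pytest test/test_cuda/test_vllm.py -k test_vllm_lm_eval
~~~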