
Commit 04fdb94

support lm_eval vllm backend

Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent 30d1061 · commit 04fdb94

File tree

4 files changed (+121, -9 lines)

auto_round/__main__.py
Lines changed: 6 additions & 0 deletions

```diff
@@ -21,6 +21,12 @@ def run_eval():
     elif "--lmms" in sys.argv:
         sys.argv.remove("--lmms")
         run_lmms()
+    elif "--vllm" in sys.argv:
+        sys.argv.remove("--vllm")
+        from auto_round.script.llm import eval_with_vllm, setup_eval_parser
+
+        args = setup_eval_parser()
+        eval_with_vllm(args)
     else:
         from auto_round.script.llm import setup_eval_parser
```
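With this dispatch in place, passing `--vllm` on the eval command line routes into the new `eval_with_vllm()` helper. Below is a minimal, hedged sketch of exercising that path from Python; it assumes `run_eval` is importable from `auto_round.__main__` (as the hunk header suggests), and the model id is a placeholder rather than anything from this commit.

```python
# Hedged sketch: simulate `auto-round --model ... --vllm ...` by setting sys.argv
# and calling the eval entry point directly. Flags mirror those used elsewhere in
# this commit; the model id is a placeholder.
import sys

sys.argv = [
    "auto-round",
    "--model", "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    "--tasks", "lambada_openai",
    "--eval_bs", "8",
    "--limit", "10",
    "--vllm",
]

from auto_round.__main__ import run_eval  # assumption: run_eval is a module-level function

run_eval()  # strips "--vllm", parses the remaining args, and calls eval_with_vllm(args)
```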

auto_round/script/llm.py
Lines changed: 82 additions & 9 deletions

```diff
@@ -42,6 +42,8 @@
     set_cuda_visible_devices,
 )
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 
 class BasicArgumentParser(argparse.ArgumentParser):
 
@@ -322,6 +324,28 @@ def __init__(self, *args, **kwargs):
             help="Limit the number of examples per task. "
             "If <1, limit is a percentage of the total number of examples.",
         )
+        # vllm related arguments
+        self.add_argument("--revision", default=None, type=str, help="model revision for vllm")
+        self.add_argument("--tokenizer", default=None, type=str, help="tokenizer to use with vllm")
+        self.add_argument(
+            "--tokenizer_mode", default="auto", type=str, help="tokenizer mode for vllm (e.g. auto/fast/slow)"
+        )
+        self.add_argument("--tokenizer_revision", default=None, type=str, help="tokenizer revision for vllm")
+        self.add_argument("--add_bos_token", action="store_true", help="add BOS token when using vllm")
+        self.add_argument("--prefix_token_id", default=None, type=int, help="prefix token id for vllm")
+        self.add_argument("--tensor_parallel_size", default=1, type=int, help="tensor parallel size for vllm")
+        self.add_argument("--data_parallel_size", default=1, type=int, help="data parallel size for vllm")
+        self.add_argument("--quantization", default=None, type=str, help="quantization setting for vllm")
+        self.add_argument("--max_gen_toks", default=256, type=int, help="max generation tokens for vllm")
+        self.add_argument("--swap_space", default=4, type=float, help="swap space (GB) for vllm")
+        self.add_argument("--max_batch_size", default=None, type=int, help="max batch size for vllm")
+        self.add_argument("--max_length", default=None, type=int, help="max generation length for vllm")
+        self.add_argument("--max_model_len", default=None, type=int, help="maximum model sequence length for vllm")
+        self.add_argument("--seed", default=1234, type=int, help="random seed")
+        self.add_argument(
+            "--gpu_memory_utilization", default=0.9, type=float, help="target GPU memory utilization for vllm"
+        )
+        self.add_argument("--lora_local_path", default=None, type=str, help="local LoRA path for vllm")
 
 
 def setup_parser():
@@ -786,15 +810,16 @@ def eval(args):
     if (batch_size := args.eval_bs) is None:
         batch_size = "auto:8"
     is_gguf_file = False
-    if os.path.isfile(args.model) and args.model.endswith(".gguf"):
-        is_gguf_file = True
-        gguf_file = os.path.basename(args.model)
-        model = os.path.dirname(args.model)
-    else:
-        for file in os.listdir(args.model):
-            if file.endswith(".gguf"):
-                is_gguf_file = True
-                gguf_file = file
+    if os.path.exists(args.model):
+        if os.path.isfile(args.model) and args.model.endswith(".gguf"):
+            is_gguf_file = True
+            gguf_file = os.path.basename(args.model)
+            model = os.path.dirname(args.model)
+        else:
+            for file in os.listdir(args.model):
+                if file.endswith(".gguf"):
+                    is_gguf_file = True
+                    gguf_file = file
     eval_model_dtype = get_model_dtype(args.eval_model_dtype)
     if is_gguf_file:
         import torch
@@ -949,3 +974,51 @@ def eval_task_by_task(
     print(make_table(res_all))
 
     print("total eval time:", time.time() - st)
+
+
+def eval_with_vllm(args):
+    import time
+
+    st = time.time()
+
+    from lm_eval import evaluator  # pylint: disable=E0401
+    from lm_eval.models.vllm_causallms import VLLM  # pylint: disable=E0401
+    from lm_eval.utils import make_table  # pylint: disable=E0401
+
+    device_str, _ = get_device_and_parallelism(args.device)
+    eval_model_dtype = get_model_dtype(args.eval_model_dtype, "auto")
+    if (batch_size := args.eval_bs) is None:
+        batch_size = "auto:8"
+
+    vllm_lm = VLLM(
+        pretrained=args.model,
+        dtype=eval_model_dtype,
+        revision=args.revision,
+        trust_remote_code=not args.disable_trust_remote_code,
+        tokenizer=args.tokenizer,
+        tokenizer_mode=args.tokenizer_mode,
+        tokenizer_revision=args.tokenizer_revision,
+        add_bos_token=args.add_bos_token,
+        prefix_token_id=args.prefix_token_id,
+        tensor_parallel_size=args.tensor_parallel_size,
+        quantization=args.quantization,
+        max_gen_toks=args.max_gen_toks,
+        swap_space=args.swap_space,
+        batch_size=batch_size,
+        max_batch_size=args.max_batch_size,
+        max_length=args.max_length,
+        max_model_len=args.max_model_len,
+        seed=args.seed,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        device=device_str,
+        data_parallel_size=args.data_parallel_size,
+        lora_local_path=args.lora_local_path,
+    )
+    res = evaluator.simple_evaluate(
+        model=vllm_lm,
+        tasks=args.tasks,
+        limit=args.limit,
+    )
+
+    print(make_table(res))
+    print("evaluation running time=%ds" % (time.time() - st))
```

docs/step_by_step.md
Lines changed: 2 additions & 0 deletions

```diff
@@ -617,6 +617,8 @@ If not explicitly specify '--task', the default value will be used (typically co
 ~~~
 The last format will be used in evaluation if multiple formats have been exported.
 
+Note: To use the vllm backend, please add `--vllm` into the upper command.
+
 ### Eval the Quantized model
 
 - AutoRound format
```
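For context, the command the doc note refers to is the CLI eval invocation. A hedged sketch of that call with the vLLM backend enabled, driven from Python via `subprocess`; the binary name and flags are taken from the diffs in this commit, while the model id is a placeholder:

```python
# Hedged sketch: run the auto-round eval CLI with the vLLM backend, including two of
# the vLLM-specific flags added to the parser in this commit. Model id is a placeholder.
import subprocess

cmd = [
    "auto-round",
    "--model", "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder
    "--eval",
    "--tasks", "lambada_openai",
    "--eval_bs", "8",
    "--vllm",
    "--tensor_parallel_size", "1",
    "--gpu_memory_utilization", "0.9",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit code
```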

test/test_cuda/test_vllm.py
Lines changed: 31 additions & 0 deletions

```diff
@@ -7,6 +7,10 @@
 Run `pytest test/test_cuda/test_vllm.py`.
 """
 
+import os
+import shutil
+import subprocess
+
 import pytest
 from vllm import LLM, SamplingParams
 from vllm.platforms import current_platform
@@ -43,3 +47,30 @@ def test_auto_round(model):
         generated_text = output.outputs[0].text
         if "France" in prompt:
             assert "Paris" in generated_text
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_vllm_lm_eval(model):
+    if shutil.which("auto-round") is None:
+        pytest.skip("auto-round CLI not available")
+
+    env = os.environ.copy()
+    env["VLLM_SKIP_WARMUP"] = "true"
+    env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+    cmd = [
+        "auto-round",
+        "--model",
+        model,
+        "--eval",
+        "--tasks",
+        "lambada_openai",
+        "--eval_bs",
+        "8",
+        "--limit",
+        "10",
+        "--vllm",
+    ]
+
+    proc = subprocess.run(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+    assert proc.returncode == 0, f"auto-round failed (rc={proc.returncode}):\n{proc.stdout}"
```
