Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sys

from auto_round.compressors import BaseCompressor
from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task, eval_with_vllm
from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task
from auto_round.schemes import PRESET_SCHEMES
from auto_round.utils import (
clear_memory,
Expand Down
15 changes: 5 additions & 10 deletions auto_round/eval/eval_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, *args, **kwargs):
self.add_argument(
"--disable_trust_remote_code", action="store_true", help="whether to disable trust_remote_code"
)
self.add_argument("--seed", default=42, type=int, help="random seed")
self.add_argument("--eval_bs", "--bs", "--batch_size", default=None, type=int, help="batch size in evaluation")
self.add_argument("--eval_task_by_task", action="store_true", help="whether to eval task by task.")
self.add_argument(
Expand Down Expand Up @@ -113,15 +114,9 @@ def _eval_init(tasks, model_path, device, disable_trust_remote_code=False, dtype

def eval(args):
if args.eval_backend == "vllm":
try:
assert isinstance(args.model, str), "vllm evaluation only supports model name or path."
eval_with_vllm(args)
return
except Exception as e: # pragma: no cover
print(f"vllm evaluation failed: {e}, fallback to default hf backend evaluation.")
args.eval_backend = "hf"
clear_memory()

assert isinstance(args.model, str), "vllm evaluation only supports model name or path."
eval_with_vllm(args)
return
tasks, model_args, device_str = _eval_init(
args.tasks, args.model, args.device_map, args.disable_trust_remote_code, args.eval_model_dtype
)
Expand Down Expand Up @@ -308,7 +303,7 @@ def eval_with_vllm(args):

st = time.time()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device_str, _ = get_device_and_parallelism(args.device)
device_str, _ = get_device_and_parallelism(args.device_map)
eval_model_dtype = get_model_dtype(args.eval_model_dtype, "auto")
if (batch_size := args.eval_bs) is None:
batch_size = "auto:8"
Expand Down