Skip to content

Commit 233a6c8

Browse files
authored
Stop wrapping vLLM evaluation in try/except (no silent fallback to the hf backend) and add back missing arguments (#884)
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent f56804a commit 233a6c8

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

auto_round/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919

2020
from auto_round.compressors import BaseCompressor
21-
from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task, eval_with_vllm
21+
from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task
2222
from auto_round.schemes import PRESET_SCHEMES
2323
from auto_round.utils import (
2424
clear_memory,

auto_round/eval/eval_cli.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, *args, **kwargs):
5454
self.add_argument(
5555
"--disable_trust_remote_code", action="store_true", help="whether to disable trust_remote_code"
5656
)
57+
self.add_argument("--seed", default=42, type=int, help="random seed")
5758
self.add_argument("--eval_bs", "--bs", "--batch_size", default=None, type=int, help="batch size in evaluation")
5859
self.add_argument("--eval_task_by_task", action="store_true", help="whether to eval task by task.")
5960
self.add_argument(
@@ -113,15 +114,9 @@ def _eval_init(tasks, model_path, device, disable_trust_remote_code=False, dtype
113114

114115
def eval(args):
115116
if args.eval_backend == "vllm":
116-
try:
117-
assert isinstance(args.model, str), "vllm evaluation only supports model name or path."
118-
eval_with_vllm(args)
119-
return
120-
except Exception as e: # pragma: no cover
121-
print(f"vllm evaluation failed: {e}, fallback to default hf backend evaluation.")
122-
args.eval_backend = "hf"
123-
clear_memory()
124-
117+
assert isinstance(args.model, str), "vllm evaluation only supports model name or path."
118+
eval_with_vllm(args)
119+
return
125120
tasks, model_args, device_str = _eval_init(
126121
args.tasks, args.model, args.device_map, args.disable_trust_remote_code, args.eval_model_dtype
127122
)
@@ -308,7 +303,7 @@ def eval_with_vllm(args):
308303

309304
st = time.time()
310305
os.environ["TOKENIZERS_PARALLELISM"] = "false"
311-
device_str, _ = get_device_and_parallelism(args.device)
306+
device_str, _ = get_device_and_parallelism(args.device_map)
312307
eval_model_dtype = get_model_dtype(args.eval_model_dtype, "auto")
313308
if (batch_size := args.eval_bs) is None:
314309
batch_size = "auto:8"

0 commit comments

Comments
 (0)