Skip to content

Commit b9245b5

Browse files
authored
cli support for positional arguments model (#979)
Signed-off-by: n1ck-guo <heng.guo@intel.com>
1 parent 3f7bdac commit b9245b5

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

auto_round/__main__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,17 @@
3636
class BasicArgumentParser(argparse.ArgumentParser):
3737
def __init__(self, *args, **kwargs):
3838
super().__init__(*args, **kwargs)
39+
self.add_argument(
40+
"model",
41+
default=None,
42+
nargs="?",
43+
help="Path to the pre-trained model or model identifier from huggingface.co/models. "
44+
"Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'",
45+
)
3946
basic = self.add_argument_group("Basic Arguments")
4047
basic.add_argument(
41-
"--model",
4248
"--model_name",
49+
"--model",
4350
"--model_name_or_path",
4451
default="facebook/opt-125m",
4552
help="Path to the pre-trained model or model identifier from huggingface.co/models. "
@@ -433,6 +440,9 @@ def setup_parser(recipe="default"):
433440

434441

435442
def tune(args):
443+
assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set."
444+
if args.model is None:
445+
args.model = args.model_name
436446
if args.eval_bs is None:
437447
args.eval_bs = "auto"
438448
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -785,6 +795,9 @@ def setup_eval_parser():
785795

786796
def run_eval():
787797
args = setup_eval_parser()
798+
assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set."
799+
if args.model is None:
800+
args.model = args.model_name
788801
if args.eval_task_by_task:
789802
eval_task_by_task(
790803
model=args.model,

auto_round/eval/eval_cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,15 @@ class EvalArgumentParser(argparse.ArgumentParser):
2828
def __init__(self, *args, **kwargs):
2929
super().__init__(*args, **kwargs)
3030
self.add_argument(
31-
"--model",
31+
"model",
32+
default=None,
33+
nargs="?",
34+
help="Path to the pre-trained model or model identifier from huggingface.co/models. "
35+
"Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'",
36+
)
37+
self.add_argument(
3238
"--model_name",
39+
"--model",
3340
"--model_name_or_path",
3441
default="facebook/opt-125m",
3542
help="Path to the pre-trained model or model identifier from huggingface.co/models. "

0 commit comments

Comments
 (0)