diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py
index 9354e01d5c..3f5fe2024a 100644
--- a/scripts/playground/reference_hf.py
+++ b/scripts/playground/reference_hf.py
@@ -36,7 +36,7 @@ def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path,
-        torch_dtype=torch.float16,
+        torch_dtype=args.dtype,
         low_cpu_mem_usage=True,
         device_map="auto",
         trust_remote_code=True,
@@ -47,7 +47,7 @@ def normal_text(args):
         "The capital of the United Kindom is",
         "Today is a sunny day and I like",
     ]
-    max_new_tokens = 16
+    max_new_tokens = args.max_new_tokens
 
     torch.cuda.set_device(0)
 
@@ -104,6 +104,16 @@ def synthetic_tokens(args):
         default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
         # default="meta-llama/Llama-2-7b-chat-hf",
     )
+    parser.add_argument(
+        "--max-new-tokens",
+        type=int,
+        default=16)
+
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16")
+
     args = parser.parse_args()
 
     normal_text(args)
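
Note on the `--dtype` flag: recent `transformers` releases accept a dtype name string (e.g. "float16", "bfloat16") for `torch_dtype`, which is what passing `args.dtype` straight through relies on. If the script ever needs to run against an older `transformers` that only accepts a `torch.dtype`, the string can be resolved first; a minimal sketch (the `resolve_dtype` helper is hypothetical, not part of this patch):

```python
import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Map a dtype name such as "float16" or "bfloat16" onto torch.float16 / torch.bfloat16.
    # Hypothetical helper, not part of the patch above.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported --dtype value: {name!r}")
    return dtype

# Usage sketch: torch_dtype=resolve_dtype(args.dtype) instead of torch_dtype=args.dtype
```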