diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 264056aa8f..4d800c2c9a 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -101,7 +101,25 @@
     default=128,
     help="Group size for per_group weight quantization. Default: 128.",
 )
-parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")
+parser.add_argument(
+    "--download_vmfb",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="download vmfb from sharktank, system dependent, YMMV",
+)
+parser.add_argument(
+    "--model_name",
+    type=str,
+    default="vicuna",
+    choices=["vicuna", "llama2_7b", "llama2_70b"],
+    help="Specify which model to run.",
+)
+parser.add_argument(
+    "--hf_auth_token",
+    type=str,
+    default=None,
+    help="Specify your own huggingface authentication tokens for models like Llama2.",
+)
 
 
 def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
@@ -870,6 +888,7 @@ def __init__(
         self,
         model_name,
         hf_model_path="TheBloke/vicuna-7B-1.1-HF",
+        hf_auth_token: str = None,
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
@@ -883,8 +902,15 @@ def __init__(
         download_vmfb=False,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
-        if self.model_name == "llama2":
+        if "llama2" in self.model_name and hf_auth_token == None:
+            raise ValueError(
+                "HF auth token required. Pass it using --hf_auth_token flag."
+            )
+        self.hf_auth_token = hf_auth_token
+        if self.model_name == "llama2_7b":
             self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
+        elif self.model_name == "llama2_70b":
+            self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
         print(f"[DEBUG] hf model name: {self.hf_model_path}")
         self.max_sequence_length = 256
         self.device = device
@@ -923,11 +949,7 @@ def get_model_path(self, model_number="first", suffix="mlir"):
         )
 
     def get_tokenizer(self):
-        kwargs = {}
-        if self.model_name == "llama2":
-            kwargs = {
-                "use_auth_token": "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"
-            }
+        kwargs = {"use_auth_token": self.hf_auth_token}
         if self.model_name == "codegen":
             tokenizer = AutoTokenizer.from_pretrained(
                 self.hf_model_path,
@@ -942,9 +964,10 @@ def get_tokenizer(self):
         return tokenizer
 
     def get_src_model(self):
-        kwargs = {"torch_dtype": torch.float}
-        if self.model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        kwargs = {
+            "torch_dtype": torch.float,
+            "use_auth_token": self.hf_auth_token,
+        }
         vicuna_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path,
             **kwargs,
@@ -1010,6 +1033,7 @@ def compile_first_vicuna(self):
             self.precision,
             self.weight_group_size,
             self.model_name,
+            self.hf_auth_token,
         )
 
         print(f"[DEBUG] generating torchscript graph")
@@ -1174,6 +1198,7 @@ def compile_second_vicuna(self):
             self.precision,
             self.weight_group_size,
             self.model_name,
+            self.hf_auth_token,
         )
 
         print(f"[DEBUG] generating torchscript graph")
@@ -1328,7 +1353,8 @@ def compile(self):
         ):
             if (self.device == "cuda" and self.precision == "fp16") or (
                 self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8" and self.download_vmfb
+                and self.precision == "int8"
+                and self.download_vmfb
             ):
                 download_public_file(
                     f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1350,7 +1376,8 @@ def compile(self):
         ):
             if (self.device == "cuda" and self.precision == "fp16") or (
                 self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8" and self.download_vmfb
+                and self.precision == "int8"
+                and self.download_vmfb
             ):
                 download_public_file(
                     f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1571,7 +1598,8 @@ def autocomplete(self, prompt):
         )
 
         vic = UnshardedVicuna(
-            "vicuna",
+            model_name=args.model_name,
+            hf_auth_token=args.hf_auth_token,
            device=args.device,
             precision=args.precision,
             first_vicuna_mlir_path=first_vic_mlir_path,
@@ -1590,21 +1618,45 @@
     else:
         config_json = None
         vic = ShardedVicuna(
-            "vicuna",
+            model_name=args.model_name,
             device=args.device,
             precision=args.precision,
             config_json=config_json,
             weight_group_size=args.weight_group_size,
         )
-    system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    if args.model_name == "vicuna":
+        system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    else:
+        system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
+        as helpfully as possible, while being safe. Your answers should not
+        include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
+        content. Please ensure that your responses are socially unbiased and positive
+        in nature. If a question does not make any sense, or is not factually coherent,
+        explain why instead of answering something not correct. If you don't know the
+        answer to a question, please don't share false information."""
     prologue_prompt = "ASSISTANT:\n"
 
     from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
+
     history = []
     set_vicuna_model(vic)
+
+    model_list = {
+        "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
+        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
+        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
+    }
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
-        history.append([user_prompt,""])
-        history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
-
+        history.append([user_prompt, ""])
+        history = list(
+            chat(
+                system_message,
+                history,
+                model=model_list[args.model_name],
+                device=args.device,
+                precision=args.precision,
+                cli=args.cli,
+            )
+        )[0]
diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py
index 8533656f25..e930f0cf6a 100644
--- a/apps/language_models/src/model_wrappers/vicuna_model.py
+++ b/apps/language_models/src/model_wrappers/vicuna_model.py
@@ -12,11 +12,12 @@ def __init__(
         precision="fp32",
         weight_group_size=128,
         model_name="vicuna",
+        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        if "llama2" in model_name:
+            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
@@ -54,11 +55,12 @@ def __init__(
         precision="fp32",
         weight_group_size=128,
         model_name="vicuna",
+        hf_auth_token: str = None,
     ):
         super().__init__()
         kwargs = {"torch_dtype": torch.float32}
-        if model_name == "llama2":
-            kwargs["use_auth_token"] = "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"
+        if "llama2" in model_name:
+            kwargs["use_auth_token"] = hf_auth_token
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **kwargs
         )
diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py
index 483cfa42a3..4db5534ad8 100644
--- a/apps/stable_diffusion/src/utils/stable_args.py
+++ b/apps/stable_diffusion/src/utils/stable_args.py
@@ -400,6 +400,13 @@ def is_valid_file(arg):
     help="Load and unload models for low VRAM.",
 )
 
+p.add_argument(
+    "--hf_auth_token",
+    type=str,
+    default=None,
+    help="Specify your own huggingface authentication tokens for models like Llama2.",
+)
+
 ##############################################################################
 # IREE - Vulkan supported flags
 ##############################################################################
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index f125df283d..7e89930d4b 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -21,7 +21,8 @@ def user(message, history):
 past_key_values = None
 
 model_map = {
-    "llama2": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
+    "llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
     "codegen": "Salesforce/codegen25-7b-multi",
     "vicuna1p3": "lmsys/vicuna-7b-v1.3",
     "vicuna": "TheBloke/vicuna-7B-1.1-HF",
@@ -30,7 +31,16 @@ def user(message, history):
 
 # NOTE: Each `model_name` should have its own start message
 start_message = {
-    "llama2": (
+    "llama2_7b": (
+        "System: You are a helpful, respectful and honest assistant. Always answer "
+        "as helpfully as possible, while being safe. Your answers should not "
+        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
+        "content. Please ensure that your responses are socially unbiased and positive "
+        "in nature. If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. If you don't know the "
+        "answer to a question, please don't share false information."
+    ),
+    "llama2_70b": (
         "System: You are a helpful, respectful and honest assistant. Always answer "
         "as helpfully as possible, while being safe. Your answers should not "
         "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
@@ -67,7 +77,13 @@ def user(message, history):
 def create_prompt(model_name, history):
     system_message = start_message[model_name]
 
-    if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
+    if model_name in [
+        "StableLM",
+        "vicuna",
+        "vicuna1p3",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         conversation = "".join(
             [
                 "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
@@ -96,10 +112,17 @@ def create_prompt(model_name, history):
 def chat(curr_system_message, history, model, device, precision, cli=True):
     global vicuna_model
     model_name, model_path = list(map(str.strip, model.split("=>")))
-    if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
+    if model_name in [
+        "vicuna",
+        "vicuna1p3",
+        "codegen",
+        "llama2_7b",
+        "llama2_70b",
+    ]:
         from apps.language_models.scripts.vicuna import (
             UnshardedVicuna,
         )
+        from apps.stable_diffusion.src import args
 
         if vicuna_model == 0:
             if "cuda" in device:
@@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
             vicuna_model = UnshardedVicuna(
                 model_name,
                 hf_model_path=model_path,
+                hf_auth_token=args.hf_auth_token,
                 device=device,
                 precision=precision,
                 max_num_tokens=max_toks,
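
For reviewers, a minimal sketch of the token-threading pattern this diff introduces. The `--model_name`/`--hf_auth_token` flags, the `"llama2" in model_name` check, the error message, and the `use_auth_token` kwarg are taken from the hunks above; the `load_model` helper itself and its signature are purely illustrative, not code from the PR.

```python
# Illustrative only: the hardcoded Hugging Face token is replaced by a
# user-supplied hf_auth_token that flows from the CLI flag down to the
# from_pretrained calls (the real code lives in UnshardedVicuna / FirstVicuna).
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(hf_model_path: str, model_name: str, hf_auth_token: str = None):
    # Mirrors the new guard in UnshardedVicuna.__init__: Llama-2 checkpoints
    # are gated on Hugging Face, so a token must be supplied by the user.
    if "llama2" in model_name and hf_auth_token is None:
        raise ValueError(
            "HF auth token required. Pass it using --hf_auth_token flag."
        )

    kwargs = {}
    if "llama2" in model_name:
        # Same kwarg the diff forwards to from_pretrained for gated repos.
        kwargs["use_auth_token"] = hf_auth_token

    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, **kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_path, low_cpu_mem_usage=True, **kwargs
    )
    return tokenizer, model


# Example CLI usage matching the new argparse flags (hypothetical invocation):
#   python apps/language_models/scripts/vicuna.py \
#       --model_name=llama2_7b --hf_auth_token=<your HF token>
```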