Add support for Llama-2-70b for web and cli, and for hf_auth_token
vivekkhandelwal1 committed Jul 20, 2023
1 parent 3662224 commit 03c4d9e
Showing 4 changed files with 111 additions and 26 deletions.
88 changes: 70 additions & 18 deletions apps/language_models/scripts/vicuna.py
@@ -101,7 +101,25 @@
default=128,
help="Group size for per_group weight quantization. Default: 128.",
)
parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")
parser.add_argument(
"--download_vmfb",
default=False,
action=argparse.BooleanOptionalAction,
help="download vmfb from sharktank, system dependent, YMMV",
)
parser.add_argument(
"--model_name",
type=str,
default="vicuna",
choices=["vicuna", "llama2_7b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)


def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
@@ -870,6 +888,7 @@ def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_auth_token: str = None,
max_num_tokens=512,
device="cuda",
precision="fp32",
@@ -883,8 +902,15 @@ def __init__(
download_vmfb=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
if self.model_name == "llama2":
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
@@ -923,11 +949,7 @@ def get_model_path(self, model_number="first", suffix="mlir"):
)

def get_tokenizer(self):
kwargs = {}
if self.model_name == "llama2":
kwargs = {
"use_auth_token": "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"
}
kwargs = {"use_auth_token": self.hf_auth_token}
if self.model_name == "codegen":
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
@@ -942,9 +964,10 @@ def get_tokenizer(self):
return tokenizer

def get_src_model(self):
kwargs = {"torch_dtype": torch.float}
if self.model_name == "llama2":
kwargs["use_auth_token"] = "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"
kwargs = {
"torch_dtype": torch.float,
"use_auth_token": self.hf_auth_token,
}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path,
**kwargs,
@@ -1010,6 +1033,7 @@ def compile_first_vicuna(self):
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)

print(f"[DEBUG] generating torchscript graph")
@@ -1174,6 +1198,7 @@ def compile_second_vicuna(self):
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)

print(f"[DEBUG] generating torchscript graph")
@@ -1328,7 +1353,8 @@ def compile(self):
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8" and self.download_vmfb
and self.precision == "int8"
and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1350,7 +1376,8 @@
):
if (self.device == "cuda" and self.precision == "fp16") or (
self.device in ["cpu-sync", "cpu-task"]
and self.precision == "int8" and self.download_vmfb
and self.precision == "int8"
and self.download_vmfb
):
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1571,7 +1598,8 @@ def autocomplete(self, prompt):
)

vic = UnshardedVicuna(
"vicuna",
model_name=args.model_name,
hf_auth_token=args.hf_auth_token,
device=args.device,
precision=args.precision,
first_vicuna_mlir_path=first_vic_mlir_path,
@@ -1590,21 +1618,45 @@ def autocomplete(self, prompt):
else:
config_json = None
vic = ShardedVicuna(
"vicuna",
model_name=args.model_name,
device=args.device,
precision=args.precision,
config_json=config_json,
weight_group_size=args.weight_group_size,
)
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
if args.model_name == "vicuna":
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
else:
system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
as helpfully as possible, while being safe. Your answers should not
include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
content. Please ensure that your responses are socially unbiased and positive
in nature. If a question does not make any sense, or is not factually coherent,
explain why instead of answering something not correct. If you don't know the
answer to a question, please don't share false information."""
prologue_prompt = "ASSISTANT:\n"

from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model

history = []
set_vicuna_model(vic)

model_list = {
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
"llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
while True:
# TODO: Add break condition from user input
user_prompt = input("User: ")
history.append([user_prompt,""])
history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]

history.append([user_prompt, ""])
history = list(
chat(
system_message,
history,
model=model_list[args.model_name],
device=args.device,
precision=args.precision,
cli=args.cli,
)
)[0]
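
For reference, a minimal sketch of exercising the new options from Python rather than the CLI loop above. This is an illustration only: it assumes the repository root is on PYTHONPATH, and the token string is a placeholder, not a value taken from this commit.

from apps.language_models.scripts.vicuna import UnshardedVicuna

# Hypothetical usage; "hf_..." is a placeholder Hugging Face access token.
vic = UnshardedVicuna(
    model_name="llama2_70b",     # resolves internally to meta-llama/Llama-2-70b-chat-hf
    hf_auth_token="hf_...",      # required for the gated Llama-2 checkpoints
    device="cpu-task",
    precision="int8",
    download_vmfb=True,          # opt in to fetching a prebuilt vmfb from shark_tank
)
# The CLI path above then registers this model via set_vicuna_model(vic) and drives it through chat().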
10 changes: 6 additions & 4 deletions apps/language_models/src/model_wrappers/vicuna_model.py
@@ -12,11 +12,12 @@ def __init__(
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -54,11 +55,12 @@ def __init__(
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if model_name == "llama2":
kwargs["use_auth_token"] = "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"
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
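For context, the kwargs threaded through above follow the standard Hugging Face pattern for loading a gated checkpoint; a standalone sketch under that assumption (the token below is a placeholder):

import torch
from transformers import AutoModelForCausalLM

hf_auth_token = "hf_..."  # placeholder; a token with access to the gated repo is required
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-chat-hf",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    use_auth_token=hf_auth_token,  # newer transformers releases accept token= instead
)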
7 changes: 7 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
@@ -400,6 +400,13 @@ def is_valid_file(arg):
help="Load and unload models for low VRAM.",
)

p.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)

##############################################################################
# IREE - Vulkan supported flags
##############################################################################
32 changes: 28 additions & 4 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -21,7 +21,8 @@ def user(message, history):
past_key_values = None

model_map = {
"llama2": "meta-llama/Llama-2-7b-chat-hf",
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
@@ -30,7 +31,16 @@

# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2": (
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
@@ -67,7 +77,13 @@ def user(message, history):
def create_prompt(model_name, history):
system_message = start_message[model_name]

if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
@@ -96,10 +112,17 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))

if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
from apps.stable_diffusion.src import args

if vicuna_model == 0:
if "cuda" in device:
Expand All @@ -117,6 +140,7 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
