fix cli for vicuna (nod-ai#1666)
dan-garvey authored Jul 18, 2023
1 parent b013659 commit 8c317e4
Showing 2 changed files with 16 additions and 20 deletions.
20 changes: 7 additions & 13 deletions apps/language_models/scripts/vicuna.py
@@ -1564,21 +1564,15 @@ def autocomplete(self, prompt):
         config_json=config_json,
         weight_group_size=args.weight_group_size,
     )
-    prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
+    system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
-    prologue_prompt = "ASSISTANT:\n"
 
+    from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
+    history = []
+    set_vicuna_model(vic)
     while True:
         # TODO: Add break condition from user input
         user_prompt = input("User: ")
-        prompt_history = (
-            prompt_history + "USER:\n" + user_prompt + prologue_prompt
-        )
-        prompt = prompt_history.strip()
-        res_str = vic.generate(prompt, cli=True)
-        torch.cuda.empty_cache()
-        gc.collect()
-        print(
-            "\n-----\nAssistant: Here's the complete formatted reply:\n",
-            res_str,
-        )
-        prompt_history += f"\n{res_str}\n"
+        history.append([user_prompt,""])
+        history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]
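Note on the rewritten loop: chat() (shown in the second file) mutates the last history entry in place and re-yields the same list object on every step, so draining the generator with list(...) and taking element [0] leaves the finished reply in history. A minimal sketch of that pattern; chat_stub is a stand-in, not the real implementation:

# Minimal sketch of the pattern behind "history = list(chat(...))[0]".
# chat_stub stands in for stablelm_ui.chat: it updates the last history
# entry and re-yields the *same* list object each time.
def chat_stub(system_message, history):
    for partial in ["Hi", "Hi there", "Hi there!"]:
        history[-1][1] = partial   # update the assistant slot in place
        yield history              # every yield returns the same object

history = [["hello", ""]]          # [user_prompt, assistant_reply]
history = list(chat_stub("system prompt", history))[0]
print(history)                     # [['hello', 'Hi there!']]

Because every yield returns the same object, the list holds N references to one history, and element [0] already contains the final generated text.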

16 changes: 9 additions & 7 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -74,14 +74,17 @@ def create_prompt(model_name, history):
     return msg
 
 
+def set_vicuna_model(model):
+    global vicuna_model
+    vicuna_model = model
+
+
+# TODO: Make chat reusable for UI and API
-def chat(curr_system_message, history, model, device, precision):
-    global sharded_model
+def chat(curr_system_message, history, model, device, precision, cli=True):
     global past_key_values
     global vicuna_model
-
-    global vicuna_model
     model_name, model_path = list(map(str.strip, model.split("=>")))
     print(f"In chat for {model_name}")
 
     if model_name in ["vicuna", "vicuna1p3", "codegen"]:
         from apps.language_models.scripts.vicuna import (
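The new set_vicuna_model() exists so the CLI in vicuna.py can hand chat() a model that was already built there, and the new cli argument is forwarded to the model's generate() call in the next hunk. The guard that skips rebuilding the model when one has been injected lives in the collapsed lines, so the sketch below states it as an assumption; FakeVicuna and load_model are stand-ins, not real SHARK APIs:

# Hedged sketch of the intended wiring (assumed reuse guard; stand-in classes).
class FakeVicuna:
    def generate(self, prompt, cli=True):
        for partial in ["Sure", "Sure, here", "Sure, here you go."]:
            yield partial

def load_model(model, device, precision):
    return FakeVicuna()

vicuna_model = None

def set_vicuna_model(model):
    global vicuna_model
    vicuna_model = model

def chat(curr_system_message, history, model, device, precision, cli=True):
    global vicuna_model
    if vicuna_model is None:                 # assumption: only build when nothing was injected
        vicuna_model = load_model(model, device, precision)
    prompt = curr_system_message + history[-1][0]   # simplified prompt assembly
    for partial_text in vicuna_model.generate(prompt, cli=cli):
        history[-1][1] = partial_text
        yield history

# CLI side, mirroring the new vicuna.py loop:
set_vicuna_model(FakeVicuna())               # inject the model built by the CLI
history = [["tell me a joke", ""]]
history = list(chat("You are helpful.\n", history, "vicuna=>path", "cpu", "fp16", cli=True))[0]
print(history[-1][1])                        # Sure, here you go.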
@@ -109,9 +112,8 @@ def chat(curr_system_message, history, model, device, precision):
             max_num_tokens=max_toks,
         )
         prompt = create_prompt(model_name, history)
-        print("prompt = ", prompt)
 
-        for partial_text in vicuna_model.generate(prompt):
+        for partial_text in vicuna_model.generate(prompt, cli=cli):
             history[-1][1] = partial_text
             yield history

@@ -140,7 +142,7 @@ def chat(curr_system_message, history, model, device, precision):

     partial_text = ""
     for new_text in words_list:
-        # print(new_text)
+        print(new_text)
         partial_text += new_text
         history[-1][1] = partial_text
         # Yield an empty string to clean up the message textbox and the updated
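The cli flag is threaded from the vicuna.py entry point through chat() into vicuna_model.generate(prompt, cli=cli). generate()'s body is not part of this diff; a plausible reading, stated here as an assumption, is that cli=True makes generation echo text to the terminal while the Gradio UI keeps consuming the yielded partial replies. A toy illustration of a streaming generator honoring such a flag, not the real UnshardedVicuna.generate:

# Illustrative only: a streaming generate() that optionally echoes to stdout.
def generate(prompt, cli=False):
    reply = ""
    for token in ["Vicuna ", "says ", "hello."]:   # stand-in for real decoding
        reply += token
        if cli:
            print(token, end="", flush=True)        # live output for CLI users
        yield reply                                  # UI consumes the partial reply
    if cli:
        print()

for partial in generate("Hi", cli=True):
    pass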
